├── .gitignore
├── hnswlib-jna-example
├── lib
│ ├── libhnswlib-jna-x86-64.dll
│ ├── libhnswlib-jna-x86-64.dylib
│ ├── libhnswlib-jna-x86-64.exp
│ ├── libhnswlib-jna-x86-64.lib
│ └── libhnswlib-jna-x86-64.so
├── pom.xml
└── src
│ └── main
│ └── java
│ └── com
│ └── stepstone
│ └── search
│ └── hnswlib
│ └── jna
│ └── example
│ └── App.java
├── hnswlib-jna
├── src
│ ├── main
│ │ ├── resources
│ │ │ ├── libhnswlib-jna-x86-64.so
│ │ │ ├── libhnswlib-jna-aarch64.so
│ │ │ ├── libhnswlib-jna-x86-64.dll
│ │ │ ├── libhnswlib-jna-x86-64.dylib
│ │ │ ├── libhnswlib-jna-x86-64.exp
│ │ │ └── libhnswlib-jna-x86-64.libw
│ │ └── java
│ │ │ └── com
│ │ │ └── stepstone
│ │ │ └── search
│ │ │ └── hnswlib
│ │ │ └── jna
│ │ │ ├── SpaceName.java
│ │ │ ├── exception
│ │ │ ├── OnceIndexIsClearedItCannotBeReusedException.java
│ │ │ ├── IndexAlreadyInitializedException.java
│ │ │ ├── UnableToCreateNewIndexInstanceException.java
│ │ │ ├── IndexNotInitializedException.java
│ │ │ ├── ItemCannotBeInsertedIntoTheVectorSpaceException.java
│ │ │ ├── UnexpectedNativeException.java
│ │ │ └── QueryCannotReturnResultsException.java
│ │ │ ├── QueryTuple.java
│ │ │ ├── HnswlibFactory.java
│ │ │ ├── Hnswlib.java
│ │ │ ├── ConcurrentIndex.java
│ │ │ └── Index.java
│ └── test
│ │ └── java
│ │ └── com
│ │ └── stepstone
│ │ └── search
│ │ └── hnswlib
│ │ └── jna
│ │ ├── ConcurrentIndexTest.java
│ │ ├── IndexPerformanceTest.java
│ │ ├── ConcurrentIndexPerformanceTest.java
│ │ ├── HnswlibTestUtils.java
│ │ ├── AbstractPerformanceTest.java
│ │ ├── IndexTest.java
│ │ └── AbstractIndexTest.java
└── pom.xml
├── hnswlib-jna-legacy
├── src
│ ├── test
│ │ └── java
│ │ │ └── com
│ │ │ └── stepstone
│ │ │ └── search
│ │ │ └── hnswlib
│ │ │ └── jna
│ │ │ ├── LegacyIndexTest.java
│ │ │ └── LegacyConcurrentIndexTest.java
│ └── main
│ │ └── java
│ │ └── com
│ │ └── stepstone
│ │ └── search
│ │ └── hnswlib
│ │ └── jna
│ │ └── HnswlibFactory.java
└── pom.xml
├── .travis.yml
├── hnswlib
├── visited_list_pool.h
├── hnswlib.h
├── bruteforce.h
├── space_l2.h
├── space_ip.h
└── hnswalg.h
├── README.md
├── pom.xml
├── bindings.cpp
└── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 | **/*.iml
2 | **/.idea/
3 | **/target/
4 | **/*.log
5 |
6 |
--------------------------------------------------------------------------------
/hnswlib-jna-example/lib/libhnswlib-jna-x86-64.dll:
--------------------------------------------------------------------------------
1 | /* this is just a placeholder for the dynamic library for windows */
2 |
--------------------------------------------------------------------------------
/hnswlib-jna-example/lib/libhnswlib-jna-x86-64.dylib:
--------------------------------------------------------------------------------
1 | /* this is just a placeholder for the dynamic library for mac */
2 |
--------------------------------------------------------------------------------
/hnswlib-jna-example/lib/libhnswlib-jna-x86-64.exp:
--------------------------------------------------------------------------------
1 | /* this is just a placeholder for the dynamic library for windows */
2 |
--------------------------------------------------------------------------------
/hnswlib-jna-example/lib/libhnswlib-jna-x86-64.lib:
--------------------------------------------------------------------------------
1 | /* this is just a placeholder for the dynamic library for windows */
2 |
--------------------------------------------------------------------------------
/hnswlib-jna-example/lib/libhnswlib-jna-x86-64.so:
--------------------------------------------------------------------------------
1 | /* this is just a placeholder for the dynamic library for linux */
2 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stepstone-tech/hnswlib-jna/HEAD/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.so
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/resources/libhnswlib-jna-aarch64.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stepstone-tech/hnswlib-jna/HEAD/hnswlib-jna/src/main/resources/libhnswlib-jna-aarch64.so
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stepstone-tech/hnswlib-jna/HEAD/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.dll
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.dylib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stepstone-tech/hnswlib-jna/HEAD/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.dylib
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.exp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stepstone-tech/hnswlib-jna/HEAD/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.exp
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.libw:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stepstone-tech/hnswlib-jna/HEAD/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.libw
--------------------------------------------------------------------------------
/hnswlib-jna-legacy/src/test/java/com/stepstone/search/hnswlib/jna/LegacyIndexTest.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | /**
4 |  * Legacy-module counterpart of {@link IndexTest}: it declares no members of its own
5 |  * and only inherits whatever the parent class defines.
6 |  */
7 | public class LegacyIndexTest extends IndexTest {}
8 |
--------------------------------------------------------------------------------
/hnswlib-jna-legacy/src/test/java/com/stepstone/search/hnswlib/jna/LegacyConcurrentIndexTest.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | /**
4 |  * Legacy-module counterpart of {@link ConcurrentIndexTest}: it declares no members of
5 |  * its own and only inherits whatever the parent class defines.
6 |  */
7 | public class LegacyConcurrentIndexTest extends ConcurrentIndexTest {}
8 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/SpaceName.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | /**
4 | * Space names available in the native implementation.
5 | */
6 | public enum SpaceName { L2, IP, COSINE /* requires normalized arrays */ }
7 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/OnceIndexIsClearedItCannotBeReusedException.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna.exception;
2 |
3 | /**
4 |  * Signals that an operation was attempted on an index that has already been cleared.
5 |  */
6 | public class OnceIndexIsClearedItCannotBeReusedException extends UnexpectedNativeException {}
7 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/IndexAlreadyInitializedException.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna.exception;
2 |
3 | /**
4 |  * Signals that initialize() was invoked more than once on the same index.
5 |  */
6 | public class IndexAlreadyInitializedException extends UnexpectedNativeException {}
7 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/ConcurrentIndexTest.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | /**
4 |  * Concrete {@link AbstractIndexTest} whose index under test is a {@link ConcurrentIndex}.
5 |  */
6 | public class ConcurrentIndexTest extends AbstractIndexTest {
7 |
8 |     /** Supplies a new {@link ConcurrentIndex} built for the given space and dimensionality. */
9 |     @Override
10 |     protected Index createIndexInstance(SpaceName spaceName, int dimensions) {
11 |         final Index concurrentIndex = new ConcurrentIndex(spaceName, dimensions);
12 |         return concurrentIndex;
13 |     }
14 | }
15 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/UnableToCreateNewIndexInstanceException.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna.exception;
2 |
3 | /**
4 |  * Signals that, for whatever reason, a new index instance could not be created
5 |  * on the native side.
6 |  */
7 | public class UnableToCreateNewIndexInstanceException extends UnexpectedNativeException {}
8 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/IndexPerformanceTest.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | import org.junit.Ignore;
4 |
5 | /**
6 |  * Concrete {@link AbstractPerformanceTest} whose index under test is a plain {@link Index}.
7 |  * The class is annotated with {@link Ignore}.
8 |  */
9 | @Ignore
10 | public class IndexPerformanceTest extends AbstractPerformanceTest {
11 |
12 |     /** Supplies a new {@link Index} built for the given space and dimensionality. */
13 |     @Override
14 |     protected Index createIndexInstance(SpaceName spaceName, int dimensions) {
15 |         final Index plainIndex = new Index(spaceName, dimensions);
16 |         return plainIndex;
17 |     }
18 | }
19 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/IndexNotInitializedException.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna.exception;
2 |
3 | /**
4 |  * Signals that the index reference has not been initialized on the native side,
5 |  * i.e. initialize() was not called after the object was instantiated.
6 |  */
7 | public class IndexNotInitializedException extends UnexpectedNativeException {}
8 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/ItemCannotBeInsertedIntoTheVectorSpaceException.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna.exception;
2 |
3 | /**
4 |  * Signals that the maximum number of elements of a vector space has been reached.
5 |  * That maximum is set when the vector space is initialized.
6 |  */
7 | public class ItemCannotBeInsertedIntoTheVectorSpaceException extends UnexpectedNativeException {}
8 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/ConcurrentIndexPerformanceTest.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | import org.junit.Ignore;
4 |
5 | /**
6 |  * Concrete {@link AbstractPerformanceTest} whose index under test is a {@link ConcurrentIndex}.
7 |  * The class is annotated with {@link Ignore}.
8 |  */
9 | @Ignore
10 | public class ConcurrentIndexPerformanceTest extends AbstractPerformanceTest {
11 |
12 |     /** Supplies a new {@link ConcurrentIndex} built for the given space and dimensionality. */
13 |     @Override
14 |     protected Index createIndexInstance(SpaceName spaceName, int dimensions) {
15 |         final Index concurrentIndex = new ConcurrentIndex(spaceName, dimensions);
16 |         return concurrentIndex;
17 |     }
18 | }
19 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/UnexpectedNativeException.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna.exception;
2 |
3 | /**
4 |  * Base unchecked exception for errors that happened in the native implementation.
5 |  */
6 | public class UnexpectedNativeException extends RuntimeException {
7 |
8 |     /** Creates the exception without a detail message. */
9 |     public UnexpectedNativeException() {
10 |         super();
11 |     }
12 |
13 |     /**
14 |      * Creates the exception with a detail message.
15 |      *
16 |      * @param message text handed to the {@link RuntimeException} constructor
17 |      */
18 |     public UnexpectedNativeException(String message) {
19 |         super(message);
20 |     }
21 | }
22 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/HnswlibTestUtils.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | import java.util.Random;
4 |
5 | /**
6 |  * Helper routines shared by the tests.
7 |  */
8 | public final class HnswlibTestUtils {
9 |
10 |     /**
11 |      * Creates an array filled with values drawn from a freshly created {@link Random}.
12 |      *
13 |      * @param dimension number of elements in the returned array
14 |      * @return a new array holding {@code dimension} results of {@link Random#nextFloat()}
15 |      */
16 |     public static float[] getRandomFloatArray(int dimension){
17 |         final float[] values = new float[dimension];
18 |         final Random generator = new Random();
19 |         int position = 0;
20 |         while (position < dimension) {
21 |             values[position] = generator.nextFloat();
22 |             position++;
23 |         }
24 |         return values;
25 |     }
26 | }
27 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/QueryCannotReturnResultsException.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna.exception;
2 |
3 | /**
4 |  * Raised when a query could not return results.
5 |  */
6 | public class QueryCannotReturnResultsException extends UnexpectedNativeException {
7 |
8 |     /** Fixed detail message attached to every instance. */
9 |     private static final String MESSAGE = "Probably ef or M is too small";
10 |
11 |     /** Creates the exception carrying the fixed {@link #MESSAGE} text. */
12 |     public QueryCannotReturnResultsException() {
13 |         super(MESSAGE);
14 |     }
15 | }
16 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/QueryTuple.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | /**
4 |  * Result holder for a knn query: two arrays, one with ids and one with coefficients.
5 |  */
6 | public class QueryTuple {
7 |
8 |     /* Both arrays are package-private and are allocated with k slots by the constructor. */
9 |     int[] ids;
10 |     float[] coefficients;
11 |
12 |     /**
13 |      * Allocates both result arrays.
14 |      *
15 |      * @param k length given to each of the two arrays
16 |      */
17 |     QueryTuple (int k) {
18 |         this.ids = new int[k];
19 |         this.coefficients = new float[k];
20 |     }
21 |
22 |     /** @return the ids array itself (not a copy) */
23 |     public int[] getIds() {
24 |         return this.ids;
25 |     }
26 |
27 |     /** @return the coefficients array itself (not a copy) */
28 |     public float[] getCoefficients() {
29 |         return this.coefficients;
30 |     }
31 | }
32 |
--------------------------------------------------------------------------------
/hnswlib-jna-example/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 | 4.0.0
6 |
7 | hnswlib-jna-example
8 | com.stepstone.search.hnswlib.jna.example
9 | hnswlib-jna-example
10 |
11 |
12 |
13 | Apache License, Version 2.0
14 | http://www.apache.org/licenses/LICENSE-2.0.txt
15 |
16 |
17 |
18 |
19 | com.stepstone.search.hnswlib.jna
20 | hnswlib-jna-parent
21 | 1.4.0
22 | ..
23 |
24 |
25 |
26 | UTF-8
27 | 1.8
28 | 1.8
29 |
30 |
31 |
32 |
33 | com.stepstone.search.hnswlib.jna
34 | hnswlib-jna
35 | 1.4.0
36 |
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | arch: amd64
2 | language: cpp
3 | compiler: clang
4 |
5 | # Every job runs `mvn test`; the jobs that list a clang++ step first rebuild the
6 | # native library from bindings.cpp into hnswlib-jna/src/main/resources.
7 | jobs:
8 |   include:
9 |     ####################################
10 |     - stage: "Unit tests on linux"     #
11 |     ####################################
12 |       os: linux
13 |       dist: bionic
14 |       env:
15 |         - JAVA_OPTS="-Xmx2048m -Xms512m"
16 |         - MAVEN_OPTS="$JAVA_OPTS"
17 |       script:
18 |         # -o takes the library path itself. The earlier `-o l <path>` wrote the build
19 |         # to a file named `l` and passed the pre-generated library to the linker as an
20 |         # input, so the tests never exercised the freshly compiled code.
21 |         - clang++ -fPIC -std=c++11 -O3 -shared bindings.cpp -I hnswlib -o hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.so
22 |         - mvn test
23 |     ####################################
24 |     - stage: "Unit tests on macos"     #
25 |     ####################################
26 |       os: osx
27 |       osx_image: xcode9.3
28 |       script:
29 |         # Same -o correction as in the linux job: write straight to the .dylib path.
30 |         - clang++ -std=c++11 -O3 -shared bindings.cpp -I hnswlib -o hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.dylib
31 |         - mvn test
32 |     ###################################################
33 |     - stage: "Unit tests on macos (no compilation)"   #
34 |     ###################################################
35 |       os: osx
36 |       osx_image: xcode9.3
37 |       script:
38 |         - mvn test
39 |     ####################################
40 |     - stage: "Unit tests on windows"   #
41 |     ####################################
42 |       os: windows
43 |       before_install:
44 |         - choco install jdk8 --version 8.0.211
45 |         - choco install maven --version 3.6.3
46 |       script:
47 |         - export JAVA_HOME="/c/Program Files/Java/jdk1.8.0_211/"
48 |         # The output name carries the x86-64 suffix used by the bundled
49 |         # libhnswlib-jna-x86-64.dll; the earlier libhnswlib-jna.dll matched no bundled
50 |         # file name. NOTE(review): confirm the architecture suffix on the CI image.
51 |         - clang++ -O3 -shared bindings.cpp -I hnswlib -o hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.dll
52 |         - /c/ProgramData/chocolatey/lib/maven/apache-maven-3.6.3/bin/mvn test
53 |     #####################################################
54 |     - stage: "Unit tests on windows (no compilation)"   #
55 |     #####################################################
56 |       os: windows
57 |       before_install:
58 |         - choco install jdk8 --version 8.0.211
59 |         - choco install maven --version 3.6.3
60 |       script:
61 |         - export JAVA_HOME="/c/Program Files/Java/jdk1.8.0_211/"
62 |         - /c/ProgramData/chocolatey/lib/maven/apache-maven-3.6.3/bin/mvn test
63 |
--------------------------------------------------------------------------------
/hnswlib/visited_list_pool.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | // NOTE(review): the header names and template arguments of this file were missing
4 | // in the reviewed text; they are restored here from the symbols the code uses
5 | // (std::mutex, memset, std::deque) -- confirm against the upstream hnswlib sources.
6 | #include <deque>
7 | #include <mutex>
8 | #include <string.h>
9 |
10 | namespace hnswlib {
11 |     typedef unsigned short int vl_type;
12 |
13 |     // Array of numelements tags (mass) plus the tag value currently in use (curV).
14 |     // NOTE(review): presumably callers mark entry i by storing curV in mass[i];
15 |     // that usage lives outside this header -- confirm.
16 |     class VisitedList {
17 |     public:
18 |         vl_type curV;
19 |         vl_type *mass;
20 |         unsigned int numelements;
21 |
22 |         // Allocates numelements1 tags without initializing them. curV starts at the
23 |         // largest vl_type value, so the first reset() wraps to 0 and clears the array.
24 |         VisitedList(int numelements1) {
25 |             curV = static_cast<vl_type>(-1);
26 |             numelements = numelements1;
27 |             mass = new vl_type[numelements];
28 |         }
29 |
30 |         // mass is an owning raw pointer: a copy would delete[] it twice.
31 |         VisitedList(const VisitedList &) = delete;
32 |         VisitedList &operator=(const VisitedList &) = delete;
33 |
34 |         // Advances to the next tag. When the tag wraps around to 0 the whole array
35 |         // is zeroed and tag 1 is used, so 0 never serves as the current tag.
36 |         void reset() {
37 |             curV++;
38 |             if (curV == 0) {
39 |                 memset(mass, 0, sizeof(vl_type) * numelements);
40 |                 curV++;
41 |             }
42 |         }
43 |
44 |         ~VisitedList() { delete[] mass; }
45 |     };
46 |     ///////////////////////////////////////////////////////////
47 |     //
48 |     // Class for multi-threaded pool-management of VisitedLists
49 |     //
50 |     /////////////////////////////////////////////////////////
51 |
52 |     class VisitedListPool {
53 |         std::deque<VisitedList *> pool;  // lists currently available for reuse
54 |         std::mutex poolguard;            // locked around every access to pool below
55 |         int numelements;                 // size handed to every VisitedList created here
56 |
57 |     public:
58 |         // Pre-creates initmaxpools lists, each sized for numelements1 entries.
59 |         VisitedListPool(int initmaxpools, int numelements1) {
60 |             numelements = numelements1;
61 |             for (int i = 0; i < initmaxpools; i++)
62 |                 pool.push_front(new VisitedList(numelements));
63 |         }
64 |
65 |         // Hands out a list that has just been reset(): one taken from the pool when
66 |         // available, otherwise a newly allocated one. The caller gives it back through
67 |         // releaseVisitedList().
68 |         VisitedList *getFreeVisitedList() {
69 |             VisitedList *rez;
70 |             {
71 |                 std::unique_lock<std::mutex> lock(poolguard);
72 |                 if (!pool.empty()) {
73 |                     rez = pool.front();
74 |                     pool.pop_front();
75 |                 } else {
76 |                     rez = new VisitedList(numelements);
77 |                 }
78 |             }
79 |             // reset() runs outside the lock: the list is no longer reachable via pool.
80 |             rez->reset();
81 |             return rez;
82 |         }
83 |
84 |         // Puts vl back at the front of the pool so the next request reuses it.
85 |         void releaseVisitedList(VisitedList *vl) {
86 |             std::unique_lock<std::mutex> lock(poolguard);
87 |             pool.push_front(vl);
88 |         }
89 |
90 |         // Deletes the lists currently held by the pool; lists that were handed out
91 |         // and not released are not touched here.
92 |         ~VisitedListPool() {
93 |             while (!pool.empty()) {
94 |                 VisitedList *rez = pool.front();
95 |                 pool.pop_front();
96 |                 delete rez;
97 |             }
98 |         }
99 |     };
100 | }
101 |
102 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/AbstractPerformanceTest.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | import com.stepstone.search.hnswlib.jna.exception.UnexpectedNativeException;
4 | import org.junit.BeforeClass;
5 | import org.junit.Ignore;
6 | import org.junit.Test;
7 |
8 | import java.time.Instant;
9 | import java.util.concurrent.ExecutorService;
10 | import java.util.concurrent.Executors;
11 | import java.util.concurrent.TimeUnit;
12 |
13 | import static org.junit.Assert.assertTrue;
14 |
15 | /**
16 |  * Insertion-speed checks shared by the index implementations. Subclasses choose the
17 |  * implementation through {@link #createIndexInstance(SpaceName, int)}.
18 |  */
19 | @Ignore
20 | public abstract class AbstractPerformanceTest {
21 |
22 |     /** Creates the index whose insertion speed is measured. */
23 |     protected abstract Index createIndexInstance(SpaceName spaceName, int dimensions);
24 |
25 |     /** Inserts 600k random 50-dimensional items from the calling thread and bounds the elapsed time. */
26 |     @Test
27 |     public void testPerformanceSingleThreadInsertionOf600kItems() throws UnexpectedNativeException {
28 |         Index index = createIndexInstance(SpaceName.COSINE, 50);
29 |         int numItems = 600_000;
30 |         index.initialize(numItems);
31 |         long begin = Instant.now().getEpochSecond();
32 |         for (int i = 0; i < numItems; i++) {
33 |             index.addItem(HnswlibTestUtils.getRandomFloatArray(50));
34 |         }
35 |         long end = Instant.now().getEpochSecond();
36 |         assertTrue((end - begin) < 600); /* +/- 8min for 1 CPU of a MacBook Pro [Intel i5 2.4GHz] (on 20/01/2020) */
37 |         index.clear();
38 |     }
39 |
40 |     /** Inserts 600k random items through a pool with one thread per available processor. */
41 |     @Test
42 |     public void testPerformanceMultiThreadedInsertionOf600kItems() throws UnexpectedNativeException, InterruptedException {
43 |         int cpus = Runtime.getRuntime().availableProcessors();
44 |         ExecutorService executorService = Executors.newFixedThreadPool(cpus);
45 |
46 |         int numItems = 600_000;
47 |         Index index = createIndexInstance(SpaceName.COSINE, 50);
48 |         index.initialize(numItems);
49 |
50 |         Runnable addItemIndex = () -> {
51 |             try {
52 |                 index.addItem(HnswlibTestUtils.getRandomFloatArray(50));
53 |             } catch (UnexpectedNativeException e) {
54 |                 /* a failed insertion is only reported; it does not fail the test */
55 |                 e.printStackTrace();
56 |             }
57 |         };
58 |
59 |         long begin = Instant.now().getEpochSecond();
60 |         for (int i = 0; i < numItems; i++) {
61 |             executorService.submit(addItemIndex);
62 |         }
63 |         executorService.shutdown();
64 |         boolean finished = executorService.awaitTermination(5, TimeUnit.MINUTES);
65 |         long end = Instant.now().getEpochSecond();
66 |         if (!finished) {
67 |             /* do not leave worker threads inserting after the test has given up */
68 |             executorService.shutdownNow();
69 |         }
70 |         assertTrue("insertion tasks did not finish within 5 minutes", finished);
71 |         assertTrue((end - begin) < 150); /* 102s ~ running on a MacBook Pro [Intel i5 2.4GHz] (on 20/01/2020) */
72 |         index.clear();
73 |     }
74 |
75 | }
76 |
--------------------------------------------------------------------------------
/hnswlib/hnswlib.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #ifndef NO_MANUAL_VECTORIZATION
3 | #ifdef __SSE__
4 | #define USE_SSE
5 | #ifdef __AVX__
6 | #define USE_AVX
7 | #endif
8 | #endif
9 | #endif
10 |
11 | // NOTE(review): the header names and template parameter lists of this file were
12 | // missing in the reviewed text; they are restored here from the symbols the code
13 | // uses -- confirm against the upstream hnswlib sources.
14 | #if defined(USE_AVX) || defined(USE_SSE)
15 | #ifdef _MSC_VER
16 | #include <intrin.h>
17 | #include <stdexcept>
18 | #else
19 | #include <x86intrin.h>
20 | #endif
21 |
22 | #if defined(__GNUC__)
23 | #define PORTABLE_ALIGN32 __attribute__((aligned(32)))
24 | #else
25 | #define PORTABLE_ALIGN32 __declspec(align(32))
26 | #endif
27 | #endif
28 |
29 | #include <queue>
30 | #include <vector>
31 | #include <iostream>
32 | #include <string.h>
33 | #include <string>
34 | #include <utility>
35 |
36 | namespace hnswlib {
37 |     typedef size_t labeltype;
38 |
39 |     // Comparator over pair-like values: true when p1.first is greater than p2.first.
40 |     template<typename T>
41 |     class pairGreater {
42 |     public:
43 |         bool operator()(const T& p1, const T& p2) const {
44 |             return p1.first > p2.first;
45 |         }
46 |     };
47 |
48 |     // Writes the sizeof(T) raw bytes of podRef to out.
49 |     template<typename T>
50 |     static void writeBinaryPOD(std::ostream &out, const T &podRef) {
51 |         out.write(reinterpret_cast<const char *>(&podRef), sizeof(T));
52 |     }
53 |
54 |     // Reads sizeof(T) raw bytes from in straight into podRef.
55 |     template<typename T>
56 |     static void readBinaryPOD(std::istream &in, T &podRef) {
57 |         in.read(reinterpret_cast<char *>(&podRef), sizeof(T));
58 |     }
59 |
60 |     // Pointer to a distance function taking three untyped pointers and returning MTYPE.
61 |     // NOTE(review): presumably (vector a, vector b, parameter block) -- confirm in the space headers.
62 |     template<typename MTYPE>
63 |     using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);
64 |
65 |
66 |     // Abstract description of a vector space: data size, distance function and the
67 |     // parameter pointer that goes with that function.
68 |     template<typename MTYPE>
69 |     class SpaceInterface {
70 |     public:
71 |         //virtual void search(void *);
72 |         virtual size_t get_data_size() = 0;
73 |
74 |         virtual DISTFUNC<MTYPE> get_dist_func() = 0;
75 |
76 |         virtual void *get_dist_func_param() = 0;
77 |
78 |         virtual ~SpaceInterface() {}
79 |     };
80 |
81 |     // Abstract interface of an index algorithm: add points, run knn searches, save.
82 |     template<typename dist_t>
83 |     class AlgorithmInterface {
84 |     public:
85 |         virtual void addPoint(const void *datapoint, labeltype label)=0;
86 |         virtual std::priority_queue<std::pair<dist_t, labeltype>> searchKnn(const void *, size_t) const = 0;
87 |         // Placeholder overload taking a comparator. It used to fall off the end of a
88 |         // non-void function (undefined behaviour when called); it now returns an
89 |         // empty result instead.
90 |         template<typename Comp>
91 |         std::vector<std::pair<dist_t, labeltype>> searchKnn(const void*, size_t, Comp) {
92 |             return std::vector<std::pair<dist_t, labeltype>>();
93 |         }
94 |         virtual void saveIndex(const std::string &location)=0;
95 |         virtual ~AlgorithmInterface(){
96 |         }
97 |     };
98 |
99 |
100 | }
101 |
102 | #include "space_l2.h"
103 | #include "space_ip.h"
104 | #include "bruteforce.h"
105 | #include "hnswalg.h"
106 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/IndexTest.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | import com.stepstone.search.hnswlib.jna.exception.IndexNotInitializedException;
4 | import com.stepstone.search.hnswlib.jna.exception.OnceIndexIsClearedItCannotBeReusedException;
5 | import com.stepstone.search.hnswlib.jna.exception.UnexpectedNativeException;
6 | import org.junit.Test;
7 |
8 | import static org.hamcrest.CoreMatchers.instanceOf;
9 | import static org.junit.Assert.assertEquals;
10 | import static org.junit.Assert.assertThat;
11 |
12 | /**
13 |  * Tests for the plain {@link Index}, in addition to whatever {@link AbstractIndexTest} defines.
14 |  */
15 | public class IndexTest extends AbstractIndexTest {
16 |
17 |     @Override
18 |     protected Index createIndexInstance(SpaceName spaceName, int dimensions) {
19 |         return new Index(spaceName, dimensions);
20 |     }
21 |
22 |     /** The synchronized wrapper reports the same length as the wrapped index and is a {@link ConcurrentIndex}. */
23 |     @Test
24 |     public void testSynchronisedIndex() throws UnexpectedNativeException {
25 |         Index i1 = createIndexInstance(SpaceName.COSINE, 50);
26 |         i1.initialize(500_000, 16, 200, 100);
27 |         Index syncIndex = Index.synchronizedIndex(i1);
28 |         /* JUnit's argument order is (expected, actual) */
29 |         assertEquals(i1.getLength(), syncIndex.getLength());
30 |         assertThat(syncIndex, instanceOf(ConcurrentIndex.class));
31 |         syncIndex.clear();
32 |     }
33 |
34 |     /** Clearing through the synchronized wrapper also makes the wrapped index unusable. */
35 |     @Test(expected = OnceIndexIsClearedItCannotBeReusedException.class)
36 |     public void testSynchronisedIndexFailAfterReferenceClear() throws UnexpectedNativeException {
37 |         Index i1 = createIndexInstance(SpaceName.COSINE, 50);
38 |         i1.initialize(500_000, 16, 200, 100);
39 |         Index syncIndex = Index.synchronizedIndex(i1);
40 |         syncIndex.clear();
41 |         //has to fail as i1 was cleared through syncIndex
42 |         i1.addItem(HnswlibTestUtils.getRandomFloatArray(50));
43 |     }
44 |
45 |     /** The similarity of a close pair must be greater than the similarity of a far pair. */
46 |     @Test
47 |     public void testComputeSimilarity() {
48 |         Index index = createIndexInstance(SpaceName.COSINE, 2);
49 |         index.initialize();
50 |         float similarityClose = index.computeSimilarity(
51 |                 new float[]{1F, 2F},
52 |                 new float[]{1F, 3F}
53 |         );
54 |         float similarityFar = index.computeSimilarity(
55 |                 new float[] {1F, 100F},
56 |                 new float[] {50F, 450F}
57 |         );
58 |         // both values are negative, so the closer one should be closer to zero than the farther one;
59 |         // Float.compare only specifies the sign of its result, hence the signum
60 |         assertEquals(1, Integer.signum(Float.compare(similarityClose, similarityFar)));
61 |     }
62 |
63 |     /** Computing a similarity before initialize() has been called must be rejected. */
64 |     @Test(expected = IndexNotInitializedException.class)
65 |     public void testComputeSimilarityWhenNotInitialized() {
66 |         Index index = createIndexInstance(SpaceName.COSINE, 2);
67 |         index.computeSimilarity(
68 |                 new float[] {1F, 100F},
69 |                 new float[] {50F, 450F});
70 |     }
71 |
72 | }
73 |
--------------------------------------------------------------------------------
/hnswlib-jna/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 | 4.0.0
6 |
7 | hnswlib-jna
8 | hnswlib-jna
9 | 1.4.0
10 |
11 |
12 |
13 | Apache License, Version 2.0
14 | http://www.apache.org/licenses/LICENSE-2.0.txt
15 |
16 |
17 |
18 |
19 | com.stepstone.search.hnswlib.jna
20 | hnswlib-jna-parent
21 | 1.4.0
22 | ..
23 |
24 |
25 |
26 | UTF-8
27 | 1.8
28 | 1.8
29 | 2.2
30 | 5.5.0
31 | 4.11
32 |
33 |
34 |
35 |
36 | net.java.dev.jna
37 | jna
38 | ${jna.version}
39 |
40 |
41 |
42 | it.unimi.dsi
43 | fastutil
44 | 8.5.4
45 |
46 |
47 | junit
48 | junit
49 | ${junit.version}
50 | test
51 |
52 |
53 |
54 |
55 |
56 |
57 | maven-assembly-plugin
58 |
59 |
60 | package
61 |
62 | single
63 |
64 |
65 |
66 |
67 |
68 | jar-with-dependencies
69 |
70 |
71 |
72 |
73 | org.apache.maven.plugins
74 | maven-jar-plugin
75 | ${jar.plugin.version}
76 |
77 |
78 |
79 | test-jar
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/HnswlibFactory.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | import com.sun.jna.Native;
4 | import com.sun.jna.Platform;
5 |
6 | import java.io.IOException;
7 | import java.io.InputStream;
8 | import java.nio.file.Files;
9 | import java.nio.file.Path;
10 | import java.nio.file.StandardCopyOption;
11 |
12 | /**
13 | * Factory for the hnswlib JNA implementation.
14 | */
15 | final class HnswlibFactory {
16 |
17 | private static final String LIBRARY_NAME = "hnswlib-jna-" + Platform.ARCH;
18 | private static final String JNA_LIBRARY_PATH_PROPERTY = "jna.library.path";
19 |
20 | private static Hnswlib instance;
21 |
22 | private HnswlibFactory() {
23 | }
24 |
25 | /**
26 | * Return a single instance of the loaded library.
27 | *
28 | * @return hnswlib JNA instance.
29 | */
30 | static synchronized Hnswlib getInstance() {
31 | if (instance == null) {
32 | try {
33 | checkIfLibraryProvidedNeedsToBeLoadedIntoSO();
34 | instance = Native.load(LIBRARY_NAME, Hnswlib.class);
35 | } catch (UnsatisfiedLinkError | IOException | NullPointerException e) {
36 | throw new UnsatisfiedLinkError("It's not possible to use the pre-generated dynamic libraries on your system. "
37 | + "Please compile it yourself (if not done yet) and set the \"" + JNA_LIBRARY_PATH_PROPERTY + "\" property "
38 | + "with correct path to where \"" + getLibraryFileName() + "\" is located.");
39 | }
40 | }
41 | return instance;
42 | }
43 |
44 | private static String getLibraryFileName(){
45 | String extension;
46 | if (Platform.isLinux()) {
47 | extension = "so";
48 | } else if (Platform.isWindows()) {
49 | extension = "dll";
50 | } else {
51 | extension = "dylib";
52 | }
53 | return String.format("libhnswlib-jna-%s.%s", Platform.ARCH, extension);
54 | }
55 |
56 | private static void copyPreGeneratedLibraryFiles(Path folder, String fileName) throws IOException {
57 | InputStream libraryStream = HnswlibFactory.class.getResourceAsStream("/" + fileName);
58 | /* windows seems to be blocking manipulation of .lib files; we store as .libw for now. */
59 | Files.copy(libraryStream, folder.resolve(fileName.replace(".libw",".lib")), StandardCopyOption.REPLACE_EXISTING);
60 | }
61 |
62 | private static void checkIfLibraryProvidedNeedsToBeLoadedIntoSO() throws IOException {
63 | String property = System.getProperty(JNA_LIBRARY_PATH_PROPERTY);
64 | if (property == null) {
65 | Path libraryFolder = Files.createTempDirectory(LIBRARY_NAME);
66 | copyPreGeneratedLibraryFiles(libraryFolder, getLibraryFileName());
67 | if (Platform.isWindows()) {
68 | copyPreGeneratedLibraryFiles(libraryFolder, "libhnswlib-jna-x86-64.exp");
69 | copyPreGeneratedLibraryFiles(libraryFolder, "libhnswlib-jna-x86-64.libw");
70 | }
71 | System.setProperty(JNA_LIBRARY_PATH_PROPERTY, libraryFolder.toString());
72 | libraryFolder.toFile().deleteOnExit();
73 | }
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/hnswlib-jna-legacy/src/main/java/com/stepstone/search/hnswlib/jna/HnswlibFactory.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | import com.sun.jna.Native;
4 | import com.sun.jna.Platform;
5 |
6 | import java.io.IOException;
7 | import java.io.InputStream;
8 | import java.nio.file.Files;
9 | import java.nio.file.Path;
10 | import java.nio.file.StandardCopyOption;
11 |
12 | /**
13 | * Factory for the hnswlib JNA (4.x.x) implementation.
14 | */
15 | public final class HnswlibFactory {
16 |
17 | private static final String LIBRARY_NAME = "hnswlib-jna-" + Platform.ARCH;
18 | private static final String JNA_LIBRARY_PATH_PROPERTY = "jna.library.path";
19 |
20 | private static Hnswlib instance;
21 |
22 | private HnswlibFactory() {
23 | }
24 |
25 | /**
26 | * Return a single instance of the loaded library.
27 | *
28 | * @return hnswlib JNA instance.
29 | */
30 | static synchronized Hnswlib getInstance() {
31 | if (instance == null) {
32 | try {
33 | checkIfLibraryProvidedNeedsToBeLoadedIntoSO();
34 | instance = (Hnswlib) Native.loadLibrary(LIBRARY_NAME, Hnswlib.class);
35 | } catch (UnsatisfiedLinkError | IOException | NullPointerException e) {
36 | throw new UnsatisfiedLinkError("It's not possible to use the pre-generated dynamic libraries on your system. "
37 | + "Please compile it yourself (if not done yet) and set the \"" + JNA_LIBRARY_PATH_PROPERTY + "\" property "
38 | + "with correct path to where \"" + getLibraryFileName() + "\" is located.");
39 | }
40 | }
41 | return instance;
42 | }
43 |
44 | private static String getLibraryFileName(){
45 | String extension;
46 | if (Platform.isLinux()) {
47 | extension = "so";
48 | } else if (Platform.isWindows()) {
49 | extension = "dll";
50 | } else {
51 | extension = "dylib";
52 | }
53 | return String.format("libhnswlib-jna-%s.%s", Platform.ARCH, extension);
54 | }
55 |
56 | private static void copyPreGeneratedLibraryFiles(Path folder, String fileName) throws IOException {
57 | InputStream libraryStream = HnswlibFactory.class.getResourceAsStream("/" + fileName);
58 | /* windows seems to be blocking manipulation of .lib files; we store as .libw for now. */
59 | Files.copy(libraryStream, folder.resolve(fileName.replace(".libw",".lib")), StandardCopyOption.REPLACE_EXISTING);
60 | }
61 |
62 | private static void checkIfLibraryProvidedNeedsToBeLoadedIntoSO() throws IOException {
63 | String property = System.getProperty(JNA_LIBRARY_PATH_PROPERTY);
64 | if (property == null) {
65 | Path libraryFolder = Files.createTempDirectory(LIBRARY_NAME);
66 | copyPreGeneratedLibraryFiles(libraryFolder, getLibraryFileName());
67 | if (Platform.isWindows()) {
68 | copyPreGeneratedLibraryFiles(libraryFolder, "libhnswlib-jna-x86-64.exp");
69 | copyPreGeneratedLibraryFiles(libraryFolder, "libhnswlib-jna-x86-64.libw");
70 | }
71 | System.setProperty(JNA_LIBRARY_PATH_PROPERTY, libraryFolder.toString());
72 | libraryFolder.toFile().deleteOnExit();
73 | }
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/hnswlib-jna-legacy/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 | 4.0.0
6 |
7 | hnswlib-jna-legacy
8 | hnswlib-jna-legacy
9 | 1.4.0
10 |
11 |
12 |
13 | Apache License, Version 2.0
14 | http://www.apache.org/licenses/LICENSE-2.0.txt
15 |
16 |
17 |
18 |
19 | com.stepstone.search.hnswlib.jna
20 | hnswlib-jna-parent
21 | 1.4.0
22 | ..
23 |
24 |
25 |
26 | 1.4.0
27 | UTF-8
28 | 1.8
29 | 1.8
30 | 3.1.2
31 | 4.2.2
32 | 4.11
33 |
34 |
35 |
36 |
37 | net.java.dev.jna
38 | jna
39 | ${jna.version}
40 |
41 |
42 | junit
43 | junit
44 | ${junit.version}
45 | test
46 |
47 |
48 | com.stepstone.search.hnswlib.jna
49 | hnswlib-jna
50 | ${hnswlib.jna.version}
51 | compile
52 |
53 |
54 | net.java.dev.jna
55 | jna
56 |
57 |
58 | true
59 |
60 |
61 | com.stepstone.search.hnswlib.jna
62 | hnswlib-jna
63 | ${hnswlib.jna.version}
64 | test-jar
65 | test
66 |
67 |
68 |
69 |
70 |
71 |
72 | maven-assembly-plugin
73 |
74 |
75 | package
76 |
77 | single
78 |
79 |
80 |
81 |
82 |
83 | jar-with-dependencies
84 |
85 |
86 |
87 |
88 |
89 | org.apache.maven.plugins
90 | maven-dependency-plugin
91 | ${maven.dependency.plugin.version}
92 |
93 |
94 | unpack
95 | prepare-package
96 |
97 | unpack
98 |
99 |
100 |
101 |
102 | com.stepstone.search.hnswlib.jna
103 | hnswlib-jna
104 | ${hnswlib.jna.version}
105 | jar
106 | false
107 | ${project.build.directory}/classes
108 | **/*.class,**/*.dll,**/*.dylib,,**/*.exp,**/*.lib*,**/*.so
109 | **/HnswlibFactory.class,**/*test.class
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
/hnswlib-jna-example/src/main/java/com/stepstone/search/hnswlib/jna/example/App.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna.example;
2 |
3 | import com.stepstone.search.hnswlib.jna.Index;
4 | import com.stepstone.search.hnswlib.jna.QueryTuple;
5 | import com.stepstone.search.hnswlib.jna.SpaceName;
6 |
7 | import java.io.File;
8 | import java.time.Instant;
9 | import java.util.Arrays;
10 | import java.util.HashMap;
11 | import java.util.Map;
12 | import java.util.Random;
13 | import java.util.concurrent.ExecutorService;
14 | import java.util.concurrent.Executors;
15 | import java.util.concurrent.TimeUnit;
16 |
17 | public class App {
18 |
19 | private static void exampleOfACosineIndex() {
20 | float[] i1 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
21 | Index.normalize(i1);
22 | float[] i2 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f};
23 | Index.normalize(i2);
24 | float[] i3 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f};
25 | Index.normalize(i3); /* For cosine, if the normalization is not explicitly done, it will be done on the native side */
26 | /* when you call index.addItem(). When explicitly done (in the Java code), use addNormalizedItem()
27 | to avoid double normalization. */
28 |
29 | Index indexCosine = new Index(SpaceName.COSINE, 7);
30 | indexCosine.initialize(3);
31 | indexCosine.addNormalizedItem(i1, 1_111_111); /* 1_111_111 is an ID */
32 | indexCosine.addNormalizedItem(i2, 2_222_222);
33 | indexCosine.addNormalizedItem(i3); /* if not defined, an incremental ID will be automatically assigned */
34 |
35 | float[] input = new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
36 | Index.normalize(input);
37 |
38 | QueryTuple cosineQT = indexCosine.knnNormalizedQuery(input, 3);
39 |
40 | System.out.println("Cosine Index - Query Results: ");
41 | System.out.println(Arrays.toString(cosineQT.getCoefficients()));
42 | System.out.println(Arrays.toString(cosineQT.getIds()));
43 | indexCosine.clear();
44 | }
45 |
46 | private static void exampleOfAInnerProductIndex() {
47 | float[] i1 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
48 | float[] i2 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f};
49 | float[] i3 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f};
50 |
51 | Index indexIP = new Index(SpaceName.IP, 7);
52 | indexIP.initialize(3, 16, 100, 200); /* set maxNumberOfElements, m, efConstruction and randomSeed */
53 | indexIP.setEf(10);
54 | indexIP.addItem(i1, 1_111_111); /* 1_111_111 is an ID */
55 | indexIP.addItem(i2, 0xCAFECAFE);
56 | indexIP.addItem(i3); /* if not defined, an incremental ID will be automatically assigned */
57 |
58 | float[] input = new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
59 |
60 | QueryTuple ipQT = indexIP.knnQuery(input, 3);
61 |
62 | System.out.println("Inner Product Index - Query Results: ");
63 | System.out.println(Arrays.toString(ipQT.getCoefficients()));
64 | System.out.println(Arrays.toString(ipQT.getIds()));
65 | indexIP.clear();
66 | }
67 |
68 | private static void exampleOfMultiThreadedIndexBuild() throws InterruptedException {
69 | int numberOfItems = 200_000;
70 | int numberOfThreads = Runtime.getRuntime().availableProcessors(); /* try numberOfThreads = 1 to see the difference ;D */
71 |
72 | /* this step is just to have some content for indexing (if you have your vectors, you're good to go) */
73 | Map vectorsMap = new HashMap<>(numberOfItems);
74 | for (int i = 0; i < numberOfItems; i++){
75 | vectorsMap.put(i , getRandomFloatArray(40));
76 | }
77 | /* ************************************************************************************************* */
78 |
79 | Index index = new Index(SpaceName.IP, 7);
80 | index.initialize(numberOfItems);
81 |
82 | long startTime = Instant.now().getEpochSecond();
83 | ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads);
84 | for (Map.Entry entry : vectorsMap.entrySet()) {
85 | executorService.submit( () -> index.addItem(entry.getValue(), entry.getKey()) );
86 | }
87 | executorService.shutdown();
88 | executorService.awaitTermination(10, TimeUnit.MINUTES);
89 | long endTime = Instant.now().getEpochSecond();
90 |
91 | System.out.println("Multi Threaded Index Build:");
92 | System.out.println("Building time for " + index.getLength() + " items took " + (endTime - startTime) + " seconds with " + numberOfThreads + " threads");
93 | }
94 |
95 | private static float[] getRandomFloatArray(int dimension){
96 | float[] array = new float[dimension];
97 | Random random = new Random();
98 | for (int i = 0; i < dimension; i++){
99 | array[i] = random.nextFloat();
100 | }
101 | return array;
102 | }
103 |
104 | /**
105 | * This is an example of how manually specify the location of the
106 | * dynamic libraries for hnswlib-jna. This step is required when
107 | * the pre-compiled ones (provided within the jars) are not sufficient
108 | * due to operating system dependencies or version of others libraries.
109 | */
110 | private static void setupHnswlibJnaDLLManually(){
111 | File projectFolder = new File("hnswlib-jna-example/lib");
112 | System.setProperty("jna.library.path", projectFolder.getAbsolutePath());
113 | }
114 |
115 | public static void main( String[] args ) throws InterruptedException {
116 | //setupHnswlibJnaDLLManually();
117 | exampleOfACosineIndex();
118 | exampleOfAInnerProductIndex();
119 | exampleOfMultiThreadedIndexBuild();
120 | }
121 |
122 | }
123 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/Hnswlib.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | import com.sun.jna.Library;
4 | import com.sun.jna.Pointer;
5 |
/**
 * JNA (Java Native Access) mapping of the native Hnswlib bindings, a fast
 * approximate nearest neighbor search library available
 * on (https://github.com/nmslib/hnswlib). This interface was created
 * in order to provide high performance in Java as well (not only in Python or C++).
 *
 * It relies on a dynamic library generated from the sources
 * available in bindings.cpp.
 */
public interface Hnswlib extends Library {

    /**
     * Allocates memory for the index in the native context and
     * stores the address in a JNA Pointer variable.
     *
     * @param spaceName - use: l2, ip or cosine strings only;
     * @param dimension - length of the vectors used for indexation.
     *
     * @return the index reference pointer.
     */
    Pointer createNewIndex(String spaceName, int dimension);

    /**
     * Initialize the index with information needed for the indexation.
     *
     * @param index - JNA pointer reference of the index;
     * @param maxNumberOfElements - max number of elements in the index;
     * @param m - the value of M;
     * @param efConstruction - ef parameter used during construction;
     * @param randomSeed - a random seed specified by the user.
     *
     * @return a result code.
     */
    int initNewIndex(Pointer index, int maxNumberOfElements, int m, int efConstruction, int randomSeed);

    /**
     * Add an item to the index.
     *
     * @param item - array containing the input to be inserted into the index;
     * @param normalized - is the item normalized? if not and if required, it will be performed at the native level;
     * @param id - an identifier to be used for this entry;
     * @param index - JNA pointer reference of the index.
     *
     * @return a result code.
     */
    int addItemToIndex(float[] item, boolean normalized, int id, Pointer index);

    /**
     * Retrieve the number of elements already inserted into the index.
     *
     * @param index - JNA pointer reference of the index.
     *
     * @return number of items in the index.
     */
    int getIndexLength(Pointer index);

    /**
     * Save the content of an index into a file (using native implementation).
     *
     * @param index - JNA pointer reference of the index;
     * @param path - path where the index will be stored.
     *
     * @return a result code.
     */
    int saveIndexToPath(Pointer index, String path);

    /**
     * Restore the content of an index saved into a file (using native implementation).
     *
     * @param index - JNA pointer reference of the index;
     * @param maxNumberOfElements - max number of items to be inserted into the index;
     * @param path - path of the file the index will be read from.
     *
     * @return a result code.
     */
    int loadIndexFromPath(Pointer index, int maxNumberOfElements, String path);

    /**
     * This function invokes the knnQuery available in the hnswlib native library.
     *
     * @param index - JNA pointer reference of the index;
     * @param input - input used for the query;
     * @param normalized - is the input normalized? if not and if required, it will be performed at the native level;
     * @param k - number of results expected from the query;
     * @param indices [output] retrieves the indices returned by the query
     *                (presumably must hold at least k entries - TODO confirm against bindings.cpp);
     * @param coefficients [output] retrieves the coefficients returned by the query
     *                (presumably must hold at least k entries - TODO confirm against bindings.cpp).
     *
     * @return a result code.
     */
    int knnQuery(Pointer index, float[] input, boolean normalized, int k, int[] indices, float[] coefficients);

    /**
     * Clear the index from the memory.
     *
     * @param index - JNA pointer reference of the index.
     *
     * @return a result code.
     */
    int clearIndex(Pointer index);

    /**
     * Sets the query time accuracy / speed trade-off value.
     *
     * @param index - JNA pointer reference of the index;
     * @param ef value.
     *
     * @return a result code.
     */
    int setEf(Pointer index, int ef);

    /**
     * Populate vector with data for given id.
     *
     * @param index - JNA pointer reference of the index;
     * @param id - identifier of the item to be read;
     * @param vector [output] array that receives the data;
     * @param dim - dimension (presumably the length of vector - TODO confirm against bindings.cpp).
     *
     * @return a result code.
     */
    int getData(Pointer index, int id, float[] vector, int dim);

    /**
     * Determine whether the index contains data for given id.
     *
     * @param index - JNA pointer reference of the index;
     * @param id - identifier to look for.
     *
     * @return a result code.
     */
    int hasId(Pointer index, int id);

    /**
     * Compute similarity between two vectors.
     *
     * @param index - JNA pointer reference of the index;
     * @param vector1 - first vector;
     * @param vector2 - second vector.
     *
     * @return similarity score between vectors.
     */
    float computeSimilarity(Pointer index, float[] vector1, float[] vector2);

    /**
     * Retrieves the value of M.
     *
     * @param index reference.
     *
     * @return value of M.
     */
    int getM(Pointer index);

    /**
     * Retrieves the current ef construction value.
     *
     * @param index reference.
     *
     * @return efConstruction value.
     */
    int getEfConstruction(Pointer index);

    /**
     * Retrieves the current ef value.
     *
     * @param index reference.
     *
     * @return EF value.
     */
    int getEf(Pointer index);

    /**
     * Marks an item ID as deleted.
     *
     * @param index reference;
     * @param id label.
     *
     * @return a result code.
     */
    int markDeleted(Pointer index, int id);

}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # __Hnswlib with JNA (Java Native Access)__
5 |
6 | This project contains a [JNA](https://github.com/java-native-access/jna) (Java Native Access) implementation built on top of the native [Hnswlib](https://github.com/nmslib/hnswlib) (Hierarchical Navigable Small World Graph) which offers a fast approximate nearest neighbor search. It includes some modifications and simplifications in order to provide Hnswlib features with native-like performance to applications written in Java. Unlike the original Python implementation, multi-thread support is not included in the bindings themselves, but it can be easily implemented on the Java side. `Hnswlib-jna` works in collaboration with a __shared library__ which contains the native code. For more information, please check the sections below.
7 |
8 | ## __Dependencies__
9 |
10 | ### __Pre-Generated Shared Library__
11 |
12 | The jar file includes some pre-generated libraries for _Windows_, _Debian Linux_ and _MacOS_ (x86-64) which should allow an easy integration and abstract all complexity related to compilation. An extra library for Debian Linux (aarch64) is also available for tests with AWS Graviton 2. In the case of operating system issues, a runtime exception will be thrown and the manual compilation will be advised.
13 |
14 | __On Windows, the [Build Tools for Visual Studio 2019 (C++ build tools)](https://visualstudio.microsoft.com/downloads/#build-tools-for-visual-studio-2019) is required__.
15 |
16 | ## __Using in Your Project__
17 |
18 | Add the following dependency in your `pom.xml`:
19 | ```
20 | <dependency>
21 |     <groupId>com.stepstone.search.hnswlib.jna</groupId>
22 |     <artifactId>hnswlib-jna</artifactId>
23 |     <version>1.4.0</version>
24 | </dependency>
25 | ```
26 |
27 | For more information and implementation details, please check [hnswlib-jna-example](./hnswlib-jna-example/).
28 |
29 | ## __Manual Compilation (Whenever it is advised)__
30 |
31 | This section includes more information about how to compile the shared libraries on Windows, Linux and Mac for different architectures (e.g., `x86-64`, `aarch64`). __If you were able to run the example project on your PC, this section can be ignored.__
32 |
33 | ### __Compiling the Shared Library__
34 |
35 | To generate the shared library required by this project, `bindings.cpp` needs to be compiled using a C++ compiler (e.g., `clang` or `gcc`) with at least C++11 support. The library can be generated with `clang` via:
36 | ```
37 | clang++ -O3 -shared bindings.cpp -I hnswlib -o <project_folder>/lib/libhnswlib-jna-x86-64.dylib
38 | ```
39 | __Note:__ The shared library's name must be: __libhnswlib-jna-ARCH.EXT__ where `ARCH` is the canonical architecture name (e.g., `x86-64` for AMD64, or `aarch64` for ARM64) and `EXT` is `dylib` for MacOS, for windows use `dll`, and linux `so`.
40 |
41 | #### Instructions for Windows
42 |
43 | ##### Using Visual Studio Build Tools
44 |
45 | 1. Download and install [LLVM](https://releases.llvm.org/9.0.0/LLVM-9.0.0-win64.exe);
46 | 2. Download and install [Build Tools for Visual Studio 2019 (or higher)](https://visualstudio.microsoft.com/downloads/#build-tools-for-visual-studio-2019);
47 | 3. Compile the bindings using `clang`:
48 | ```
49 | clang++ -O3 -shared bindings.cpp -I hnswlib -o <project_folder>/lib/libhnswlib-jna-x86-64.dll
50 | ```
51 | This procedure will generate the 3 necessary files: `libhnswlib-jna-x86-64.dll`, `libhnswlib-jna-x86-64.exp` and `libhnswlib-jna-x86-64.lib`.
52 |
53 | ##### Using MinGW64
54 |
55 | 1. Download and install [LLVM](https://releases.llvm.org/9.0.0/LLVM-9.0.0-win64.exe);
56 | 2. Make sure that LLVM's bin folder is in your PATH;
57 | 3. Download [MinGW-w64 with Headers for Clang](https://sourceforge.net/projects/mingw-w64/files/Toolchains%20targetting%20Win64/Personal%20Builds/mingw-builds/8.1.0/threads-posix/seh/);
58 | 4. Unpack the archive and include MinGW64's bin folder into your PATH as well;
59 | 5. Compile the bindings using `clang`:
60 | ```
61 | clang++ -O3 -target x86_64-pc-windows-gnu -shared bindings.cpp -I hnswlib -o lib/libhnswlib-jna-x86-64.dll -lpthread
62 | ```
63 | This procedure will generate `libhnswlib-jna-x86-64.dll`.
64 |
65 | #### Instructions for Linux
66 |
67 | 1. Download and install `clang` (older versions might trigger compilation issues, so it is better to use a recent version);
68 | 2. Compile the bindings using `clang`:
69 | ```
70 | clang++ -O3 -fPIC -shared -std=c++11 bindings.cpp -I hnswlib -o <project_folder>/lib/libhnswlib-jna-x86-64.so
71 | ```
72 | This procedure will generate `libhnswlib-jna-x86-64.so`.
73 |
74 | ### __Reading the Shared Library in Your Project__
75 |
76 | Once the shared library is available, it is necessary to tell the `JVM` and `JNA` where it is located. This can be done by setting the property `jna.library.path` via JVM parameters or system properties.
77 |
78 | #### Via JVM parameters
79 | ```
80 | -Djna.library.path=<project_folder>/lib
81 | ```
82 | #### Programmatically via System Class
83 | ```
84 | System.setProperty("jna.library.path", "<project_folder>/lib");
85 | ```
86 | For more information and implementation details, please check [hnswlib-jna-example](./hnswlib-jna-example/).
87 |
88 | ## License
89 | Copyright 2020 StepStone Services
90 |
91 | Licensed under the Apache License, Version 2.0 (the "License");
92 | you may not use this file except in compliance with the License.
93 | You may obtain a copy of the License at
94 |
95 | [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)
96 |
97 | Unless required by applicable law or agreed to in writing, software
98 | distributed under the License is distributed on an "AS IS" BASIS,
99 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
100 | See the License for the specific language governing permissions and
101 | limitations under the License.
102 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 | 4.0.0
6 |
7 | hnswlib
8 | com.stepstone.search.hnswlib.jna
9 | hnswlib-jna-parent
10 | 1.5.0
11 | pom
12 |
13 | This project contains a JNA (Java Native Access) implementation built on top of the native Hnswlib (Hierarchical Navigable Small World Graph) which offers a fast approximate nearest neighbor search. It includes some modifications and simplifications in order to provide Hnswlib features with native like performance to applications written in Java. Differently from the original Python implementation, the multi-thread support is not included in the bindings itself but it can be easily implemented on the Java side.
14 | https://github.com/stepstone-tech/hnswlib-jna
15 |
16 |
17 | scm:git:git://github.com/stepstone-tech/hnswlib-jna.git
18 | scm:git:ssh://github.com:stepstone-tech/hnswlib-jna.git
19 | https://github.com/stepstone-tech/hnswlib-jna/tree/master
20 |
21 |
22 |
23 |
24 | Alex Docherty
25 | alexander.docherty@stepstone.com
26 | StepStone
27 | https://www.stepstone.com
28 |
29 |
30 | Casper Davies
31 | casper.davies@stepstone.com
32 | StepStone
33 | https://www.stepstone.com
34 |
35 |
36 | German Hurtado
37 | german.hurtado@stepstone.com
38 | StepStone
39 | https://www.stepstone.com
40 |
41 |
42 | Henri David
43 | henri.david@stepstone.com
44 | StepStone
45 | https://www.stepstone.com
46 |
47 |
48 | Hussama Ismail
49 | hussama.ismail@stepstone.com
50 | StepStone
51 | https://www.stepstone.com
52 |
53 |
54 | Stefan Skoruppa
55 | stefan.skoruppa@stepstone.com
56 | StepStone
57 | https://www.stepstone.com
58 |
59 |
60 | Tomasz Wojtun
61 | tomasz.wojtun@stepstone.com
62 | StepStone
63 | https://www.stepstone.com
64 |
65 |
66 | Vinitha Venugopalsavithri
67 | vinitha.venugopalsavithri@stepstone.com
68 | StepStone
69 | https://www.stepstone.com
70 |
71 |
72 | Zhenhua Mai
73 | zhenhua.mai@stepstone.com
74 | StepStone
75 | https://www.stepstone.com
76 |
77 |
78 |
79 |
80 |
81 | Apache License, Version 2.0
82 | http://www.apache.org/licenses/LICENSE-2.0.txt
83 |
84 |
85 |
86 |
87 | hnswlib-jna
88 | hnswlib-jna-legacy
89 | hnswlib-jna-example
90 |
91 |
92 |
93 |
94 | ossrh
95 | https://oss.sonatype.org/content/repositories/snapshots
96 |
97 |
98 |
99 |
100 |
101 |
102 | org.sonatype.plugins
103 | nexus-staging-maven-plugin
104 | 1.6.7
105 | true
106 |
107 | ossrh
108 | https://oss.sonatype.org/
109 | true
110 |
111 |
112 |
113 | org.apache.maven.plugins
114 | maven-source-plugin
115 | 2.2.1
116 |
117 |
118 | attach-sources
119 |
120 | jar-no-fork
121 |
122 |
123 |
124 |
125 |
126 | org.apache.maven.plugins
127 | maven-javadoc-plugin
128 | 2.9.1
129 |
130 |
131 | attach-javadocs
132 |
133 | jar
134 |
135 |
136 |
137 |
138 |
139 | org.apache.maven.plugins
140 | maven-gpg-plugin
141 | 1.5
142 |
143 |
144 | sign-artifacts
145 | verify
146 |
147 | sign
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
--------------------------------------------------------------------------------
/hnswlib/bruteforce.h:
--------------------------------------------------------------------------------
1 | #pragma once
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
6 |
7 | namespace hnswlib {
8 | template
9 | class BruteforceSearch : public AlgorithmInterface {
10 | public:
11 | BruteforceSearch(SpaceInterface *s) {
12 |
13 | }
14 | BruteforceSearch(SpaceInterface *s, const std::string &location) {
15 | loadIndex(location, s);
16 | }
17 |
18 | BruteforceSearch(SpaceInterface *s, size_t maxElements) {
19 | maxelements_ = maxElements;
20 | data_size_ = s->get_data_size();
21 | fstdistfunc_ = s->get_dist_func();
22 | dist_func_param_ = s->get_dist_func_param();
23 | size_per_element_ = data_size_ + sizeof(labeltype);
24 | data_ = (char *) malloc(maxElements * size_per_element_);
25 | if (data_ == nullptr)
26 | std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
27 | cur_element_count = 0;
28 | }
29 |
30 | ~BruteforceSearch() {
31 | free(data_);
32 | }
33 |
34 | char *data_;
35 | size_t maxelements_;
36 | size_t cur_element_count;
37 | size_t size_per_element_;
38 |
39 | size_t data_size_;
40 | DISTFUNC fstdistfunc_;
41 | void *dist_func_param_;
42 | std::mutex index_lock;
43 |
44 | std::unordered_map dict_external_to_internal;
45 |
46 | void addPoint(const void *datapoint, labeltype label) {
47 |
48 | int idx;
49 | {
50 | std::unique_lock lock(index_lock);
51 |
52 |
53 |
54 | auto search=dict_external_to_internal.find(label);
55 | if (search != dict_external_to_internal.end()) {
56 | idx=search->second;
57 | }
58 | else{
59 | if (cur_element_count >= maxelements_) {
60 | throw std::runtime_error("The number of elements exceeds the specified limit\n");
61 | }
62 | idx=cur_element_count;
63 | dict_external_to_internal[label] = idx;
64 | cur_element_count++;
65 | }
66 | }
67 | memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
68 | memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
69 |
70 |
71 |
72 |
73 | };
74 |
75 | void removePoint(labeltype cur_external) {
76 | size_t cur_c=dict_external_to_internal[cur_external];
77 |
78 | dict_external_to_internal.erase(cur_external);
79 |
80 | labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
81 | dict_external_to_internal[label]=cur_c;
82 | memcpy(data_ + size_per_element_ * cur_c,
83 | data_ + size_per_element_ * (cur_element_count-1),
84 | data_size_+sizeof(labeltype));
85 | cur_element_count--;
86 |
87 | }
88 |
89 |
90 | std::priority_queue>
91 | searchKnn(const void *query_data, size_t k) const {
92 | std::priority_queue> topResults;
93 | if (cur_element_count == 0) return topResults;
94 | for (int i = 0; i < k; i++) {
95 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
96 | topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i +
97 | data_size_))));
98 | }
99 | dist_t lastdist = topResults.top().first;
100 | for (int i = k; i < cur_element_count; i++) {
101 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
102 | if (dist <= lastdist) {
103 | topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i +
104 | data_size_))));
105 | if (topResults.size() > k)
106 | topResults.pop();
107 | lastdist = topResults.top().first;
108 | }
109 |
110 | }
111 | return topResults;
112 | };
113 |
114 | template
115 | std::vector>
116 | searchKnn(const void* query_data, size_t k, Comp comp) {
117 | std::vector> result;
118 | if (cur_element_count == 0) return result;
119 |
120 | auto ret = searchKnn(query_data, k);
121 |
122 | while (!ret.empty()) {
123 | result.push_back(ret.top());
124 | ret.pop();
125 | }
126 |
127 | std::sort(result.begin(), result.end(), comp);
128 |
129 | return result;
130 | }
131 |
132 | void saveIndex(const std::string &location) {
133 | std::ofstream output(location, std::ios::binary);
134 | std::streampos position;
135 |
136 | writeBinaryPOD(output, maxelements_);
137 | writeBinaryPOD(output, size_per_element_);
138 | writeBinaryPOD(output, cur_element_count);
139 |
140 | output.write(data_, maxelements_ * size_per_element_);
141 |
142 | output.close();
143 | }
144 |
145 | void loadIndex(const std::string &location, SpaceInterface *s) {
146 |
147 |
148 | std::ifstream input(location, std::ios::binary);
149 | std::streampos position;
150 |
151 | readBinaryPOD(input, maxelements_);
152 | readBinaryPOD(input, size_per_element_);
153 | readBinaryPOD(input, cur_element_count);
154 |
155 | data_size_ = s->get_data_size();
156 | fstdistfunc_ = s->get_dist_func();
157 | dist_func_param_ = s->get_dist_func_param();
158 | size_per_element_ = data_size_ + sizeof(labeltype);
159 | data_ = (char *) malloc(maxelements_ * size_per_element_);
160 | if (data_ == nullptr)
161 | std::runtime_error("Not enough memory: loadIndex failed to allocate data");
162 |
163 | input.read(data_, maxelements_ * size_per_element_);
164 |
165 | input.close();
166 |
167 | }
168 |
169 | };
170 | }
171 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/ConcurrentIndex.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | import java.nio.file.Path;
4 | import java.util.Optional;
5 | import java.util.concurrent.locks.Lock;
6 | import java.util.concurrent.locks.ReadWriteLock;
7 | import java.util.concurrent.locks.ReentrantReadWriteLock;
8 |
9 | /**
10 | * This class offers a thread-safe alternative for a small-world Index.
11 | * It allows concurrent item insertion and querying which is not supported
12 | * by the native Hnswlib implementation.
13 | *
14 | * Note: this class relies on a ReadWriteLock with fairness enabled, so
15 | * insertions coming from multiple threads are serialized. To take advantage of
16 | * parallel insertion, please create an Index instance and then retrieve a
17 | * ConcurrentIndex from it via the Index.synchronizedIndex() method call.
18 | */
19 | public class ConcurrentIndex extends Index {
20 |
/* Read/write lock created with fairness enabled; the two views below are taken from it. */
private ReadWriteLock readWriteLock = new ReentrantReadWriteLock(true);
/* Acquired by getLength() and the knn query methods. */
private Lock readLock = readWriteLock.readLock();
/* Acquired by the addItem() / addNormalizedItem() methods. */
private Lock writeLock = readWriteLock.writeLock();

/**
 * Creates a thread-safe index by forwarding both arguments to the Index constructor.
 *
 * @param spaceName - vector space used by the index;
 * @param dimensions - length of the vectors handled by the index.
 */
public ConcurrentIndex(SpaceName spaceName, int dimensions) {
    super(spaceName, dimensions);
}
28 |
29 | /**
30 | * Thread-safe method which adds an item without ID to the index.
31 | * Internally, an incremental ID (starting from 1) will be given to this item.
32 | *
33 | * @param item - float array with the length expected by the index (dimension).
34 | */
35 | @Override
36 | public void addItem(float[] item) {
37 | this.writeLock.lock();
38 | try {
39 | super.addItem(item, NO_ID);
40 | } finally {
41 | this.writeLock.unlock();
42 | }
43 | }
44 |
45 | /**
46 | * Thread-safe method which adds an item with ID to the index.
47 | * It won't apply any extra normalization unless it is required
48 | * by the Vector Space (e.g., COSINE).
49 | *
50 | * @param item - float array with the length expected by the index (dimension);
51 | * @param id - an identifier used by the native library.
52 | */
53 | @Override
54 | public void addItem(float[] item, int id) {
55 | this.writeLock.lock();
56 | try {
57 | super.addItem(item, id);
58 | } finally {
59 | this.writeLock.unlock();
60 | }
61 | }
62 |
63 | /**
64 | * Thread-safe method which adds a normalized item without ID to the index.
65 | * Internally, an incremental ID (starting from 0) will be given to this item.
66 | *
67 | * @param item - float array with the length expected by the index (dimension).
68 | */
69 | @Override
70 | public void addNormalizedItem(float[] item) {
71 | this.writeLock.lock();
72 | try {
73 | super.addNormalizedItem(item, Index.NO_ID);
74 | } finally {
75 | this.writeLock.unlock();
76 | }
77 | }
78 |
79 | /**
80 | * Thread-safe method which adds a normalized item with ID to the index.
81 | *
82 | * @param item - float array with the length expected by the index (dimension);
83 | * @param id - an identifier used by the native library.
84 | */
85 | @Override
86 | public void addNormalizedItem(float[] item, int id) {
87 | this.writeLock.lock();
88 | try {
89 | super.addNormalizedItem(item, id);
90 | } finally {
91 | this.writeLock.unlock();
92 | }
93 | }
94 |
95 | /**
96 | * Thread-safe method which returns the number of elements
97 | * already inserted in the index.
98 | *
99 | * @return elements count.
100 | */
101 | @Override
102 | public int getLength(){
103 | this.readLock.lock();
104 | try {
105 | return super.getLength();
106 | } finally {
107 | this.readLock.unlock();
108 | }
109 | }
110 |
111 | /**
112 | * Thread-safe method which performs a knn query in the index instance.
113 | * In case the vector space requires the input to be normalized, it will
114 | * normalize at the native level.
115 | *
116 | * @param input - float array;
117 | * @param k - number of results expected.
118 | *
119 | * @return a query tuple instance that contain the indices and coefficients.
120 | */
121 | @Override
122 | public QueryTuple knnQuery(float[] input, int k) {
123 | this.readLock.lock();
124 | QueryTuple queryTuple;
125 | try {
126 | queryTuple = super.knnQuery(input, k);
127 | } finally {
128 | this.readLock.unlock();
129 | }
130 | return queryTuple;
131 | }
132 |
133 | /**
134 | * Thread-safe method which performs a knn query in the index instance
135 | * using an normalized input. It will not normalize the vector again.
136 | *
137 | * @param input - a normalized float array;
138 | * @param k - number of results expected.
139 | *
140 | * @return a query tuple instance that contain the indices and coefficients.
141 | */
142 | @Override
143 | public QueryTuple knnNormalizedQuery(float[] input, int k) {
144 | this.readLock.lock();
145 | QueryTuple queryTuple;
146 | try {
147 | queryTuple = super.knnNormalizedQuery(input, k);
148 | } finally {
149 | this.readLock.unlock();
150 | }
151 | return queryTuple;
152 | }
153 |
154 | /**
155 | * Thread-safe method which stores the content of the index into a file.
156 | * This method relies on the native implementation.
157 | *
158 | * @param path - destination path.
159 | */
160 | @Override
161 | public void save(Path path) {
162 | this.readLock.lock();
163 | try {
164 | super.save(path);
165 | } finally {
166 | this.readLock.unlock();
167 | }
168 | }
169 |
170 | /**
171 | * Thread-safe method which loads the content stored in a file path onto the index.
172 | *
173 | * Note: if the index was previously initialized, the old
174 | * content will be erased.
175 | *
176 | * @param path - path to the index file;
177 | * @param maxNumberOfElements - max number of elements in the index.
178 | */
179 | @Override
180 | public void load(Path path, int maxNumberOfElements) {
181 | this.writeLock.lock();
182 | try {
183 | super.load(path, maxNumberOfElements);
184 | } finally {
185 | this.writeLock.unlock();
186 | }
187 | }
188 |
189 | /**
190 | * Thread-safe method which frees the memory allocated for this index in the native context.
191 | *
192 | * NOTE: Once the index is cleared, it cannot be initialized or used again.
193 | */
194 | @Override
195 | public void clear() {
196 | this.writeLock.lock();
197 | try {
198 | super.clear();
199 | } finally {
200 | this.writeLock.unlock();
201 | }
202 | }
203 |
204 | /**
205 | * Thread-safe method which sets the query time accuracy / speed trade-off value.
206 | *
207 | * @param ef value.
208 | */
209 | @Override
210 | public void setEf(int ef) {
211 | this.writeLock.lock();
212 | try {
213 | super.setEf(ef);
214 | } finally {
215 | this.writeLock.unlock();
216 | }
217 | }
218 |
219 | /**
220 | * Thread-safe method that checks whether there is an item with the specified identifier in the index.
221 | *
222 | * @param id - identifier.
223 | *
224 | * @return true or false.
225 | */
226 | public boolean hasId(int id) {
227 | this.readLock.lock();
228 | boolean hasId;
229 | try {
230 | hasId = super.hasId(id);
231 | } finally {
232 | this.readLock.unlock();
233 | }
234 | return hasId;
235 | }
236 |
237 | /**
238 | * Thread-safe method that marks an ID as deleted.
239 | *
240 | * @param id identifier.
241 | */
242 | public void markDeleted(int id) {
243 | this.writeLock.lock();
244 | try {
245 | super.markDeleted(id);
246 | } finally {
247 | this.writeLock.unlock();
248 | }
249 | }
250 |
251 | /**
252 | * Thread-safe method that gets the data from a specific identifier in the index.
253 | *
254 | * @param id - identifier.
255 | *
256 | * @return an optional containing or not the
257 | */
258 | public Optional getData(int id) {
259 | this.readLock.lock();
260 | Optional data;
261 | try {
262 | data = super.getData(id);
263 | } finally {
264 | this.readLock.unlock();
265 | }
266 | return data;
267 | }
268 |
269 | }
270 |
--------------------------------------------------------------------------------
/hnswlib/space_l2.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "hnswlib.h"
3 |
4 | namespace hnswlib {
5 |
// Scalar squared Euclidean distance between two float vectors.
//   pVect1v, pVect2v - the two vectors, read as arrays of float;
//   qty_ptr          - points to a size_t holding the number of components.
// Returns the sum of squared component differences (0 for an empty range).
static float
L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    const float *lhs = (const float *) pVect1v;
    const float *rhs = (const float *) pVect2v;
    const size_t count = *((const size_t *) qty_ptr);

    float acc = 0;
    for (size_t idx = 0; idx != count; ++idx) {
        const float delta = lhs[idx] - rhs[idx];
        acc += delta * delta;
    }
    return acc;
}
21 |
22 | #if defined(USE_AVX)
23 |
24 | // Favor using AVX if available.
25 | static float
26 | L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
27 | float *pVect1 = (float *) pVect1v;
28 | float *pVect2 = (float *) pVect2v;
29 | size_t qty = *((size_t *) qty_ptr);
30 | float PORTABLE_ALIGN32 TmpRes[8];
31 | size_t qty16 = qty >> 4;
32 |
33 | const float *pEnd1 = pVect1 + (qty16 << 4);
34 |
35 | __m256 diff, v1, v2;
36 | __m256 sum = _mm256_set1_ps(0);
37 |
38 | while (pVect1 < pEnd1) {
39 | v1 = _mm256_loadu_ps(pVect1);
40 | pVect1 += 8;
41 | v2 = _mm256_loadu_ps(pVect2);
42 | pVect2 += 8;
43 | diff = _mm256_sub_ps(v1, v2);
44 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
45 |
46 | v1 = _mm256_loadu_ps(pVect1);
47 | pVect1 += 8;
48 | v2 = _mm256_loadu_ps(pVect2);
49 | pVect2 += 8;
50 | diff = _mm256_sub_ps(v1, v2);
51 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
52 | }
53 |
54 | _mm256_store_ps(TmpRes, sum);
55 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
56 | }
57 |
58 | #elif defined(USE_SSE)
59 |
60 | static float
61 | L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
62 | float *pVect1 = (float *) pVect1v;
63 | float *pVect2 = (float *) pVect2v;
64 | size_t qty = *((size_t *) qty_ptr);
65 | float PORTABLE_ALIGN32 TmpRes[8];
66 | size_t qty16 = qty >> 4;
67 |
68 | const float *pEnd1 = pVect1 + (qty16 << 4);
69 |
70 | __m128 diff, v1, v2;
71 | __m128 sum = _mm_set1_ps(0);
72 |
73 | while (pVect1 < pEnd1) {
74 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
75 | v1 = _mm_loadu_ps(pVect1);
76 | pVect1 += 4;
77 | v2 = _mm_loadu_ps(pVect2);
78 | pVect2 += 4;
79 | diff = _mm_sub_ps(v1, v2);
80 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
81 |
82 | v1 = _mm_loadu_ps(pVect1);
83 | pVect1 += 4;
84 | v2 = _mm_loadu_ps(pVect2);
85 | pVect2 += 4;
86 | diff = _mm_sub_ps(v1, v2);
87 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
88 |
89 | v1 = _mm_loadu_ps(pVect1);
90 | pVect1 += 4;
91 | v2 = _mm_loadu_ps(pVect2);
92 | pVect2 += 4;
93 | diff = _mm_sub_ps(v1, v2);
94 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
95 |
96 | v1 = _mm_loadu_ps(pVect1);
97 | pVect1 += 4;
98 | v2 = _mm_loadu_ps(pVect2);
99 | pVect2 += 4;
100 | diff = _mm_sub_ps(v1, v2);
101 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
102 | }
103 |
104 | _mm_store_ps(TmpRes, sum);
105 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
106 | }
107 | #endif
108 |
109 | #if defined(USE_SSE) || defined(USE_AVX)
110 | static float
111 | L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
112 | size_t qty = *((size_t *) qty_ptr);
113 | size_t qty16 = qty >> 4 << 4;
114 | float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
115 | float *pVect1 = (float *) pVect1v + qty16;
116 | float *pVect2 = (float *) pVect2v + qty16;
117 |
118 | size_t qty_left = qty - qty16;
119 | float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
120 | return (res + res_tail);
121 | }
122 | #endif
123 |
124 |
125 | #ifdef USE_SSE
126 | static float
127 | L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
128 | float PORTABLE_ALIGN32 TmpRes[8];
129 | float *pVect1 = (float *) pVect1v;
130 | float *pVect2 = (float *) pVect2v;
131 | size_t qty = *((size_t *) qty_ptr);
132 |
133 |
134 | size_t qty4 = qty >> 2;
135 |
136 | const float *pEnd1 = pVect1 + (qty4 << 2);
137 |
138 | __m128 diff, v1, v2;
139 | __m128 sum = _mm_set1_ps(0);
140 |
141 | while (pVect1 < pEnd1) {
142 | v1 = _mm_loadu_ps(pVect1);
143 | pVect1 += 4;
144 | v2 = _mm_loadu_ps(pVect2);
145 | pVect2 += 4;
146 | diff = _mm_sub_ps(v1, v2);
147 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
148 | }
149 | _mm_store_ps(TmpRes, sum);
150 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
151 | }
152 |
153 | static float
154 | L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
155 | size_t qty = *((size_t *) qty_ptr);
156 | size_t qty4 = qty >> 2 << 2;
157 |
158 | float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
159 | size_t qty_left = qty - qty4;
160 |
161 | float *pVect1 = (float *) pVect1v + qty4;
162 | float *pVect2 = (float *) pVect2v + qty4;
163 | float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
164 |
165 | return (res + res_tail);
166 | }
167 | #endif
168 |
169 | class L2Space : public SpaceInterface {
170 |
171 | DISTFUNC fstdistfunc_;
172 | size_t data_size_;
173 | size_t dim_;
174 | public:
175 | L2Space(size_t dim) {
176 | fstdistfunc_ = L2Sqr;
177 | #if defined(USE_SSE) || defined(USE_AVX)
178 | if (dim % 16 == 0)
179 | fstdistfunc_ = L2SqrSIMD16Ext;
180 | else if (dim % 4 == 0)
181 | fstdistfunc_ = L2SqrSIMD4Ext;
182 | else if (dim > 16)
183 | fstdistfunc_ = L2SqrSIMD16ExtResiduals;
184 | else if (dim > 4)
185 | fstdistfunc_ = L2SqrSIMD4ExtResiduals;
186 | #endif
187 | dim_ = dim;
188 | data_size_ = dim * sizeof(float);
189 | }
190 |
191 | size_t get_data_size() {
192 | return data_size_;
193 | }
194 |
195 | DISTFUNC get_dist_func() {
196 | return fstdistfunc_;
197 | }
198 |
199 | void *get_dist_func_param() {
200 | return &dim_;
201 | }
202 |
203 | ~L2Space() {}
204 | };
205 |
// Squared L2 distance between two vectors of unsigned 8-bit components.
//   pVect1, pVect2 - the two vectors, read as arrays of unsigned char;
//   qty_ptr        - points to a size_t holding the number of components.
// Returns the sum of squared component differences as an int.
//
// The main loop is unrolled four times; the trailing (qty % 4) components are
// handled by a scalar tail loop so that dimensions that are not a multiple of
// four are fully accounted for.
static int
L2SqrI(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {

    const size_t qty = *((const size_t *) qty_ptr);
    const unsigned char *a = (const unsigned char *) pVect1;
    const unsigned char *b = (const unsigned char *) pVect2;
    int res = 0;

    const size_t blocks = qty >> 2;
    for (size_t i = 0; i < blocks; i++) {
        // unsigned char operands are promoted to int before the subtraction.
        const int d0 = a[0] - b[0];
        const int d1 = a[1] - b[1];
        const int d2 = a[2] - b[2];
        const int d3 = a[3] - b[3];
        res += d0 * d0;
        res += d1 * d1;
        res += d2 * d2;
        res += d3 * d3;
        a += 4;
        b += 4;
    }

    // Tail: the components left over after the unrolled blocks.
    for (size_t i = blocks << 2; i < qty; i++) {
        const int d = (*a) - (*b);
        res += d * d;
        a++;
        b++;
    }

    return (res);
}
236 |
237 | class L2SpaceI : public SpaceInterface {
238 |
239 | DISTFUNC fstdistfunc_;
240 | size_t data_size_;
241 | size_t dim_;
242 | public:
243 | L2SpaceI(size_t dim) {
244 | fstdistfunc_ = L2SqrI;
245 | dim_ = dim;
246 | data_size_ = dim * sizeof(unsigned char);
247 | }
248 |
249 | size_t get_data_size() {
250 | return data_size_;
251 | }
252 |
253 | DISTFUNC get_dist_func() {
254 | return fstdistfunc_;
255 | }
256 |
257 | void *get_dist_func_param() {
258 | return &dim_;
259 | }
260 |
261 | ~L2SpaceI() {}
262 | };
263 |
264 |
265 | }
--------------------------------------------------------------------------------
/hnswlib/space_ip.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "hnswlib.h"
3 |
4 | namespace hnswlib {
5 |
// Scalar inner-product distance between two float vectors.
//   pVect1, pVect2 - the two vectors, read as arrays of float;
//   qty_ptr        - points to a size_t holding the number of components.
// Returns 1.0f minus the dot product (1.0f for an empty range).
static float
InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) {
    const size_t qty = *((const size_t *) qty_ptr);
    const float *lhs = (const float *) pVect1;
    const float *rhs = (const float *) pVect2;

    float res = 0;
    // The counter has the same width as qty, so it cannot wrap before reaching it.
    for (size_t i = 0; i < qty; i++) {
        res += lhs[i] * rhs[i];
    }
    return (1.0f - res);
}
16 |
17 | #if defined(USE_AVX)
18 |
19 | // Favor using AVX if available.
20 | static float
21 | InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
22 | float PORTABLE_ALIGN32 TmpRes[8];
23 | float *pVect1 = (float *) pVect1v;
24 | float *pVect2 = (float *) pVect2v;
25 | size_t qty = *((size_t *) qty_ptr);
26 |
27 | size_t qty16 = qty / 16;
28 | size_t qty4 = qty / 4;
29 |
30 | const float *pEnd1 = pVect1 + 16 * qty16;
31 | const float *pEnd2 = pVect1 + 4 * qty4;
32 |
33 | __m256 sum256 = _mm256_set1_ps(0);
34 |
35 | while (pVect1 < pEnd1) {
36 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
37 |
38 | __m256 v1 = _mm256_loadu_ps(pVect1);
39 | pVect1 += 8;
40 | __m256 v2 = _mm256_loadu_ps(pVect2);
41 | pVect2 += 8;
42 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
43 |
44 | v1 = _mm256_loadu_ps(pVect1);
45 | pVect1 += 8;
46 | v2 = _mm256_loadu_ps(pVect2);
47 | pVect2 += 8;
48 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
49 | }
50 |
51 | __m128 v1, v2;
52 | __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
53 |
54 | while (pVect1 < pEnd2) {
55 | v1 = _mm_loadu_ps(pVect1);
56 | pVect1 += 4;
57 | v2 = _mm_loadu_ps(pVect2);
58 | pVect2 += 4;
59 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
60 | }
61 |
62 | _mm_store_ps(TmpRes, sum_prod);
63 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];;
64 | return 1.0f - sum;
65 | }
66 |
67 | #elif defined(USE_SSE)
68 |
69 | static float
70 | InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
71 | float PORTABLE_ALIGN32 TmpRes[8];
72 | float *pVect1 = (float *) pVect1v;
73 | float *pVect2 = (float *) pVect2v;
74 | size_t qty = *((size_t *) qty_ptr);
75 |
76 | size_t qty16 = qty / 16;
77 | size_t qty4 = qty / 4;
78 |
79 | const float *pEnd1 = pVect1 + 16 * qty16;
80 | const float *pEnd2 = pVect1 + 4 * qty4;
81 |
82 | __m128 v1, v2;
83 | __m128 sum_prod = _mm_set1_ps(0);
84 |
85 | while (pVect1 < pEnd1) {
86 | v1 = _mm_loadu_ps(pVect1);
87 | pVect1 += 4;
88 | v2 = _mm_loadu_ps(pVect2);
89 | pVect2 += 4;
90 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
91 |
92 | v1 = _mm_loadu_ps(pVect1);
93 | pVect1 += 4;
94 | v2 = _mm_loadu_ps(pVect2);
95 | pVect2 += 4;
96 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
97 |
98 | v1 = _mm_loadu_ps(pVect1);
99 | pVect1 += 4;
100 | v2 = _mm_loadu_ps(pVect2);
101 | pVect2 += 4;
102 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
103 |
104 | v1 = _mm_loadu_ps(pVect1);
105 | pVect1 += 4;
106 | v2 = _mm_loadu_ps(pVect2);
107 | pVect2 += 4;
108 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
109 | }
110 |
111 | while (pVect1 < pEnd2) {
112 | v1 = _mm_loadu_ps(pVect1);
113 | pVect1 += 4;
114 | v2 = _mm_loadu_ps(pVect2);
115 | pVect2 += 4;
116 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
117 | }
118 |
119 | _mm_store_ps(TmpRes, sum_prod);
120 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
121 |
122 | return 1.0f - sum;
123 | }
124 |
125 | #endif
126 |
127 | #if defined(USE_AVX)
128 |
129 | static float
130 | InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
131 | float PORTABLE_ALIGN32 TmpRes[8];
132 | float *pVect1 = (float *) pVect1v;
133 | float *pVect2 = (float *) pVect2v;
134 | size_t qty = *((size_t *) qty_ptr);
135 |
136 | size_t qty16 = qty / 16;
137 |
138 |
139 | const float *pEnd1 = pVect1 + 16 * qty16;
140 |
141 | __m256 sum256 = _mm256_set1_ps(0);
142 |
143 | while (pVect1 < pEnd1) {
144 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
145 |
146 | __m256 v1 = _mm256_loadu_ps(pVect1);
147 | pVect1 += 8;
148 | __m256 v2 = _mm256_loadu_ps(pVect2);
149 | pVect2 += 8;
150 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
151 |
152 | v1 = _mm256_loadu_ps(pVect1);
153 | pVect1 += 8;
154 | v2 = _mm256_loadu_ps(pVect2);
155 | pVect2 += 8;
156 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
157 | }
158 |
159 | _mm256_store_ps(TmpRes, sum256);
160 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
161 |
162 | return 1.0f - sum;
163 | }
164 |
165 | #elif defined(USE_SSE)
166 |
167 | static float
168 | InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
169 | float PORTABLE_ALIGN32 TmpRes[8];
170 | float *pVect1 = (float *) pVect1v;
171 | float *pVect2 = (float *) pVect2v;
172 | size_t qty = *((size_t *) qty_ptr);
173 |
174 | size_t qty16 = qty / 16;
175 |
176 | const float *pEnd1 = pVect1 + 16 * qty16;
177 |
178 | __m128 v1, v2;
179 | __m128 sum_prod = _mm_set1_ps(0);
180 |
181 | while (pVect1 < pEnd1) {
182 | v1 = _mm_loadu_ps(pVect1);
183 | pVect1 += 4;
184 | v2 = _mm_loadu_ps(pVect2);
185 | pVect2 += 4;
186 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
187 |
188 | v1 = _mm_loadu_ps(pVect1);
189 | pVect1 += 4;
190 | v2 = _mm_loadu_ps(pVect2);
191 | pVect2 += 4;
192 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
193 |
194 | v1 = _mm_loadu_ps(pVect1);
195 | pVect1 += 4;
196 | v2 = _mm_loadu_ps(pVect2);
197 | pVect2 += 4;
198 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
199 |
200 | v1 = _mm_loadu_ps(pVect1);
201 | pVect1 += 4;
202 | v2 = _mm_loadu_ps(pVect2);
203 | pVect2 += 4;
204 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
205 | }
206 | _mm_store_ps(TmpRes, sum_prod);
207 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
208 |
209 | return 1.0f - sum;
210 | }
211 |
212 | #endif
213 |
214 | #if defined(USE_SSE) || defined(USE_AVX)
215 | static float
216 | InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
217 | size_t qty = *((size_t *) qty_ptr);
218 | size_t qty16 = qty >> 4 << 4;
219 | float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16);
220 | float *pVect1 = (float *) pVect1v + qty16;
221 | float *pVect2 = (float *) pVect2v + qty16;
222 |
223 | size_t qty_left = qty - qty16;
224 | float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
225 | return res + res_tail - 1.0f;
226 | }
227 |
228 | static float
229 | InnerProductSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
230 | size_t qty = *((size_t *) qty_ptr);
231 | size_t qty4 = qty >> 2 << 2;
232 |
233 | float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4);
234 | size_t qty_left = qty - qty4;
235 |
236 | float *pVect1 = (float *) pVect1v + qty4;
237 | float *pVect2 = (float *) pVect2v + qty4;
238 | float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
239 |
240 | return res + res_tail - 1.0f;
241 | }
242 | #endif
243 |
244 | class InnerProductSpace : public SpaceInterface {
245 |
246 | DISTFUNC fstdistfunc_;
247 | size_t data_size_;
248 | size_t dim_;
249 | public:
250 | InnerProductSpace(size_t dim) {
251 | fstdistfunc_ = InnerProduct;
252 | #if defined(USE_AVX) || defined(USE_SSE)
253 | if (dim % 16 == 0)
254 | fstdistfunc_ = InnerProductSIMD16Ext;
255 | else if (dim % 4 == 0)
256 | fstdistfunc_ = InnerProductSIMD4Ext;
257 | else if (dim > 16)
258 | fstdistfunc_ = InnerProductSIMD16ExtResiduals;
259 | else if (dim > 4)
260 | fstdistfunc_ = InnerProductSIMD4ExtResiduals;
261 | #endif
262 | dim_ = dim;
263 | data_size_ = dim * sizeof(float);
264 | }
265 |
266 | size_t get_data_size() {
267 | return data_size_;
268 | }
269 |
270 | DISTFUNC get_dist_func() {
271 | return fstdistfunc_;
272 | }
273 |
274 | void *get_dist_func_param() {
275 | return &dim_;
276 | }
277 |
278 | ~InnerProductSpace() {}
279 | };
280 |
281 |
282 | }
283 |
--------------------------------------------------------------------------------
/bindings.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Simplified C binding for hnswlib able to work with JNA (Java Native Access)
3 | * in order to get a similar native performance on Java. This code is based on the python
4 | * binding available at: https://github.com/nmslib/hnswlib/blob/master/python_bindings/bindings.cpp
5 | *
6 | * Some modifications and simplifications have been done on the C side.
7 | * The multithread support can be used and handled on the Java side.
8 | *
9 | * This work is still in progress. Please feel free to contribute and give ideas.
10 | */
11 |
12 | #include
13 | #include
14 | #include
15 | #include "hnswlib/hnswlib.h"
16 |
// Symbol export decoration: __declspec(dllexport) on Windows, nothing elsewhere.
#if _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif

// Gives the exported functions C linkage.
#define EXTERN_C extern "C"

// Integer result codes returned by the binding functions.
#define RESULT_SUCCESSFUL                                     0
#define RESULT_EXCEPTION_THROWN                               1
#define RESULT_INDEX_ALREADY_INITIALIZED                      2
#define RESULT_QUERY_CANNOT_RETURN                            3
#define RESULT_ITEM_CANNOT_BE_INSERTED_INTO_THE_VECTOR_SPACE  4
#define RESULT_ONCE_INDEX_IS_CLEARED_IT_CANNOT_BE_REUSED      5
#define RESULT_GET_DATA_FAILED                                6
#define RESULT_ID_NOT_IN_INDEX                                7
#define RESULT_INDEX_NOT_INITIALIZED                          8

// Function-body helper: expects `index_cleared` to be visible at the point of
// use. Returns early when the index was cleared, otherwise runs `block`, maps
// any exception to RESULT_EXCEPTION_THROWN and returns the result code.
// `block` may itself return a different code.
#define TRY_CATCH_NO_INITIALIZE_CHECK_AND_RETURN_INT_BLOCK(block)               \
    if (index_cleared) return RESULT_ONCE_INDEX_IS_CLEARED_IT_CANNOT_BE_REUSED; \
    int result_code = RESULT_SUCCESSFUL;                                        \
    try { block } catch (...) { result_code = RESULT_EXCEPTION_THROWN; };       \
    return result_code;

// Same as above, but first requires `index_initialized` to be true.
#define TRY_CATCH_RETURN_INT_BLOCK(block)                            \
    if (!index_initialized) return RESULT_INDEX_NOT_INITIALIZED;     \
    TRY_CATCH_NO_INITIALIZE_CHECK_AND_RETURN_INT_BLOCK(block)
37 |
// Wrapper around one hnswlib space object (l2space) and one HierarchicalNSW
// graph (appr_alg), both allocated with new and released in clear_index().
// Most operations report an integer RESULT_* code instead of throwing.
// NOTE(review): the bare `template` keyword and several type names below
// (HierarchicalNSW, std::priority_queue, std::atomic) appear without template
// argument lists - this looks like an extraction artifact; restore from
// version control before compiling.
template
class Index {
public:
// Picks the distance space from space_name: "L2" -> L2Space, "IP" and
// "COSINE" -> InnerProductSpace; "COSINE" also enables normalization.
// NOTE(review): for any other name l2space is left unassigned, yet
// clear_index() deletes it - confirm callers only pass these three names.
Index(const std::string &space_name, const int dim) :
space_name(space_name), dim(dim) {
data_must_be_normalized = false;
if(space_name=="L2") {
l2space = new hnswlib::L2Space(dim);
} else if(space_name=="IP") {
l2space = new hnswlib::InnerProductSpace(dim);
} else if(space_name=="COSINE") {
l2space = new hnswlib::InnerProductSpace(dim);
data_must_be_normalized = true;
}
appr_alg = NULL;
index_initialized = false;
index_cleared = false;
}

// Allocates the graph and marks the index as initialized.
// Returns RESULT_INDEX_ALREADY_INITIALIZED when a graph already exists,
// RESULT_ONCE_INDEX_IS_CLEARED_IT_CANNOT_BE_REUSED after clear_index(),
// RESULT_EXCEPTION_THROWN if anything throws, RESULT_SUCCESSFUL otherwise.
int init_new_index(const size_t maxElements, const size_t M, const size_t efConstruction, const size_t random_seed) {
TRY_CATCH_NO_INITIALIZE_CHECK_AND_RETURN_INT_BLOCK({
if (appr_alg) {
return RESULT_INDEX_ALREADY_INITIALIZED;
}
appr_alg = new hnswlib::HierarchicalNSW(l2space, maxElements, M, efConstruction, random_seed);
index_initialized = true;
});
}

// Writes ef into the graph's ef_ field; requires an initialized index.
int set_ef(size_t ef) {
TRY_CATCH_RETURN_INT_BLOCK({
appr_alg->ef_ = ef;
});
}

// The three getters below read graph fields directly.
// NOTE(review): they dereference appr_alg without any initialized/null check.
int get_ef() {
return appr_alg->ef_;
}

int get_ef_construction() {
return appr_alg->ef_construction_;
}

int get_M() {
return appr_alg->M_;
}

// Delegates to the graph's saveIndex(); requires an initialized index.
int save_index(const std::string &path_to_index) {
TRY_CATCH_RETURN_INT_BLOCK({
appr_alg->saveIndex(path_to_index);
});
}

// Replaces the graph with one constructed from the file at path_to_index; an
// existing graph is deleted first (with a warning on std::cerr).
// NOTE(review): index_initialized is not set here, so the methods guarded by
// TRY_CATCH_RETURN_INT_BLOCK would report RESULT_INDEX_NOT_INITIALIZED after a
// load unless init_new_index() ran earlier - confirm this is intended.
// NOTE(review): if the constructor throws after the delete, appr_alg keeps
// pointing at the deleted graph.
int load_index(const std::string &path_to_index, size_t max_elements) {
TRY_CATCH_NO_INITIALIZE_CHECK_AND_RETURN_INT_BLOCK({
if (appr_alg) {
std::cerr << "Warning: Calling load_index for an already initialized index. Old index is being deallocated.";
delete appr_alg;
}
appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, max_elements);
});
}

// NOTE(review): the line starting with `for` below is damaged - the rest of
// normalize_array and the header of the following insertion method (the one
// that uses item, item_normalized and id) are missing from this text. What
// remains is the tail of that insertion method: it rejects the item with
// RESULT_ITEM_CANNOT_BE_INSERTED_INTO_THE_VECTOR_SPACE based on a comparison
// against get_max_elements(), normalizes the item when the space requires it
// and the caller did not, then adds the point under `id`, or under the next
// incremental_id value when id is -1.
void normalize_array(float* array){
float norm = 0.0f;
for (int i=0; i= get_max_elements()) {
return RESULT_ITEM_CANNOT_BE_INSERTED_INTO_THE_VECTOR_SPACE;
}
if ((data_must_be_normalized == true) && (item_normalized == false)) {
normalize_array(item);
}
int current_id = id != -1 ? id : incremental_id++;
appr_alg->addPoint(item, current_id);
});
}

// Returns RESULT_SUCCESSFUL when id is found in the graph's label_lookup_ and
// is not marked deleted, RESULT_ID_NOT_IN_INDEX otherwise.
// label_c is declared but never used in this method.
int hasId(int id) {
TRY_CATCH_RETURN_INT_BLOCK({
int label_c;
auto search = (appr_alg->label_lookup_.find(id));
if (search == (appr_alg->label_lookup_.end()) || (appr_alg->isMarkedDeleted(search->second))) {
return RESULT_ID_NOT_IN_INDEX;
}
});
}

// Copies `dim` floats of the stored vector for `id` into `data`.
// Returns RESULT_ID_NOT_IN_INDEX when the id is unknown or marked deleted.
// The `dim` parameter shadows the member of the same name.
int getDataById(int id, float* data, int dim) {
TRY_CATCH_RETURN_INT_BLOCK({
int label_c;
auto search = (appr_alg->label_lookup_.find(id));
if (search == (appr_alg->label_lookup_.end()) || (appr_alg->isMarkedDeleted(search->second))) {
return RESULT_ID_NOT_IN_INDEX;
}
label_c = search->second;
char* data_ptrv = (appr_alg->getDataByInternalId(label_c));
float* data_ptr = (float*) data_ptrv;
for (int i = 0; i < dim; i++) {
data[i] = *data_ptr;
data_ptr += 1;
}
});
}

// Applies the graph's distance function to the two vectors; yields NAN when
// the call throws.
// NOTE(review): appr_alg is dereferenced without an initialized/null check.
float compute_similarity(float* vector1, float* vector2) {
float similarity;
try {
similarity = (appr_alg->fstdistfunc_(vector1, vector2, (appr_alg -> dist_func_param_)));
} catch (...) {
similarity = NAN;
}
return similarity;
}

// Runs searchKnn for k neighbours and writes the results into the output
// arrays from position k - 1 down to 0, in the order they are popped from the
// result queue. Calls normalize_array(input) first when the space requires
// normalization and the caller did not normalize.
// Returns RESULT_QUERY_CANNOT_RETURN when the result count differs from k.
int knn_query(float* input, bool input_normalized, int k, int* indices /* output */, float* coefficients /* output */) {
std::priority_queue> result;
TRY_CATCH_RETURN_INT_BLOCK({
if ((data_must_be_normalized == true) && (input_normalized == false)) {
normalize_array(input);
}
result = appr_alg->searchKnn((void*) input, k);
if (result.size() != k)
return RESULT_QUERY_CANNOT_RETURN;
for (int i = k - 1; i >= 0; i--) {
auto &result_tuple = result.top();
coefficients[i] = result_tuple.first;
indices[i] = result_tuple.second;
result.pop();
}
});
}

// Delegates to the graph's markDelete(); requires an initialized index.
int mark_deleted(int label) {
TRY_CATCH_RETURN_INT_BLOCK({
appr_alg->markDelete(label);
});
}

// NOTE(review): the three methods below dereference appr_alg without any
// initialized/null check, and resize_index has no try/catch around it.
void resize_index(size_t new_size) {
appr_alg->resizeIndex(new_size);
}

int get_max_elements() const {
return appr_alg->max_elements_;
}

int get_current_count() const {
return appr_alg->cur_element_count;
}

// Deletes the space and, if present, the graph, then sets index_cleared so
// that every later macro-guarded call (including a second clear_index)
// returns RESULT_ONCE_INDEX_IS_CLEARED_IT_CANNOT_BE_REUSED.
// NOTE(review): l2space and appr_alg are not reset after delete, so unguarded
// accessors would touch freed memory afterwards.
int clear_index() {
TRY_CATCH_NO_INITIALIZE_CHECK_AND_RETURN_INT_BLOCK({
delete l2space;
if (appr_alg)
delete appr_alg;
index_cleared = true;
});
}

std::string space_name;
int dim;
// Set by clear_index(); checked first by both TRY_CATCH_* macros.
bool index_cleared;
// Set by init_new_index(); checked by TRY_CATCH_RETURN_INT_BLOCK.
bool index_initialized;
// True only for the "COSINE" space.
bool data_must_be_normalized;
// Source of labels for items inserted with id == -1; starts at 0.
std::atomic incremental_id{0};
hnswlib::HierarchicalNSW *appr_alg;
hnswlib::SpaceInterface *l2space;

// Releases the native objects through clear_index(); a no-op if it already ran.
~Index() {
clear_index();
}
};
220 |
221 | EXTERN_C DLLEXPORT Index* createNewIndex(char* spaceName, int dimension){
222 | Index* index;
223 | try {
224 | index = new Index(spaceName, dimension);
225 | } catch (...) {
226 | index = NULL;
227 | }
228 | return index;
229 | }
230 |
231 | EXTERN_C DLLEXPORT int initNewIndex(Index* index, int maxNumberOfElements, int M = 16, int efConstruction = 200, int randomSeed = 100) {
232 | return index->init_new_index(maxNumberOfElements, M, efConstruction, randomSeed);
233 | }
234 |
235 | EXTERN_C DLLEXPORT int addItemToIndex(float* item, int normalized, int label, Index* index) {
236 | return index->add_item(item, normalized, label);
237 | }
238 |
239 | EXTERN_C DLLEXPORT int getIndexLength(Index* index) {
240 | if (index->appr_alg) {
241 | return index->appr_alg->cur_element_count;
242 | } else {
243 | return 0;
244 | }
245 | }
246 |
247 | EXTERN_C DLLEXPORT int saveIndexToPath(Index* index, char* path) {
248 | std::string path_string(path);
249 | return index->save_index(path_string);
250 | }
251 |
252 | EXTERN_C DLLEXPORT int loadIndexFromPath(Index* index, size_t maxNumberOfElements, char* path) {
253 | std::string path_string(path);
254 | return index->load_index(path_string, maxNumberOfElements);
255 | }
256 |
257 | EXTERN_C DLLEXPORT int knnQuery(Index* index, float* input, int normalized, int k, int* indices /* output */, float* coefficients /* output */) {
258 | return index->knn_query(input, normalized, k, indices, coefficients);
259 | }
260 |
261 | EXTERN_C DLLEXPORT int clearIndex(Index* index) {
262 | return index->clear_index();
263 | }
264 |
265 | EXTERN_C DLLEXPORT int setEf(Index* index, int ef) {
266 | return index->set_ef(ef);
267 | }
268 |
269 | EXTERN_C DLLEXPORT int getData(Index* index, int id, float* vector, int dim) {
270 | return index->getDataById(id, vector, dim);
271 | }
272 |
273 | EXTERN_C DLLEXPORT int hasId(Index* index, int id) {
274 | return index->hasId(id);
275 | }
276 |
277 | EXTERN_C DLLEXPORT float computeSimilarity(Index* index, float* vector1, float* vector2) {
278 | return index->compute_similarity(vector1, vector2);
279 | }
280 |
281 | EXTERN_C DLLEXPORT int getM(Index* index) {
282 | return index->get_M();
283 | }
284 |
285 | EXTERN_C DLLEXPORT int getEfConstruction(Index* index) {
286 | return index->get_ef_construction();
287 | }
288 |
289 | EXTERN_C DLLEXPORT int getEf(Index* index) {
290 | return index->get_ef();
291 | }
292 |
293 | EXTERN_C DLLEXPORT int markDeleted(Index* index, int id) {
294 | return index->mark_deleted(id);
295 | }
296 |
297 | int main(){
298 | return RESULT_SUCCESSFUL;
299 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright {yyyy} {name of copyright owner}
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/Index.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | import com.stepstone.search.hnswlib.jna.exception.IndexAlreadyInitializedException;
4 | import com.stepstone.search.hnswlib.jna.exception.IndexNotInitializedException;
5 | import com.stepstone.search.hnswlib.jna.exception.ItemCannotBeInsertedIntoTheVectorSpaceException;
6 | import com.stepstone.search.hnswlib.jna.exception.OnceIndexIsClearedItCannotBeReusedException;
7 | import com.stepstone.search.hnswlib.jna.exception.QueryCannotReturnResultsException;
8 | import com.stepstone.search.hnswlib.jna.exception.UnableToCreateNewIndexInstanceException;
9 | import com.stepstone.search.hnswlib.jna.exception.UnexpectedNativeException;
10 | import com.sun.jna.Pointer;
11 | import it.unimi.dsi.fastutil.ints.IntArraySet;
12 | import it.unimi.dsi.fastutil.ints.IntSet;
13 | import it.unimi.dsi.fastutil.ints.IntSets;
14 |
15 | import java.nio.file.Path;
16 | import java.util.Optional;
17 |
18 | /**
19 | * Represents a small world index in the java context.
20 | * This class includes some abstraction to make the integration
21 | * with the native library a bit more java like and relies on the
22 | * JNA implementation.
23 | *
24 | * Each instance of index has a different memory context and should
25 | * work independently.
26 | */
27 | public class Index {
28 |
29 | protected static final int NO_ID = -1;
30 | private static final int RESULT_SUCCESSFUL = 0;
31 | private static final int RESULT_QUERY_NO_RESULTS = 3;
32 | private static final int RESULT_ITEM_CANNOT_BE_INSERTED_INTO_THE_VECTOR_SPACE = 4;
33 | private static final int RESULT_ONCE_INDEX_IS_CLEARED_IT_CANNOT_BE_REUSED = 5;
34 | private static final int RESULT_INDEX_NOT_INITIALIZED = 8;
35 |
36 | private static Hnswlib hnswlib = HnswlibFactory.getInstance();
37 |
38 | private Pointer reference;
39 | private boolean initialized;
40 | private boolean cleared;
41 | private SpaceName spaceName;
42 | private int dimension;
43 | private boolean referenceReused;
44 | private IntSet ids = IntSets.synchronize(new IntArraySet());
45 |
46 | public Index(SpaceName spaceName, int dimension) {
47 | this.spaceName = spaceName;
48 | this.dimension = dimension;
49 | reference = hnswlib.createNewIndex(spaceName.toString(), dimension);
50 | if (reference == null) {
51 | throw new UnableToCreateNewIndexInstanceException();
52 | }
53 | }
54 |
55 | /**
56 | * This method initializes the index with the default values
57 | * for the parameters m, efConstruction, randomSeed and sets
58 | * the maxNumberOfElements to 1_000_000.
59 | *
60 | * Note: not setting the maxNumberOfElements might lead to out of memory
61 | * issues and unpredictable behaviours in your application. Thus, use this
62 | * method wisely and combine it with monitoring.
63 | *
64 | * For more information, please @see {link #initialize(int, int, int, int)}.
65 | */
66 | public void initialize() {
67 | initialize(1_000_000);
68 | }
69 |
70 | /**
71 | * For more information, please @see {link #initialize(int, int, int, int)}.
72 | *
73 | * @param maxNumberOfElements allowed in the index.
74 | */
75 | public void initialize(int maxNumberOfElements) {
76 | initialize(maxNumberOfElements, 16, 200, 100);
77 | }
78 |
79 | /**
80 | * Initialize the index to be used.
81 | *
82 | * @param maxNumberOfElements ;
83 | * @param m ;
84 | * @param efConstruction ;
85 | * @param randomSeed .
86 | *
87 | * @throws IndexAlreadyInitializedException when a index reference was initialized before.
88 | * @throws UnexpectedNativeException when something unexpected happened in the native side.
89 | */
90 | public void initialize(int maxNumberOfElements, int m, int efConstruction, int randomSeed) {
91 | if (initialized) {
92 | throw new IndexAlreadyInitializedException();
93 | } else {
94 | checkResultCode(hnswlib.initNewIndex(reference, maxNumberOfElements, m, efConstruction, randomSeed));
95 | initialized = true;
96 | }
97 | }
98 |
99 | /**
100 | * Sets the query time accuracy / speed trade-off value.
101 | *
102 | * @param ef value.
103 | */
104 | public void setEf(int ef) {
105 | checkResultCode(hnswlib.setEf(reference, ef));
106 | }
107 |
108 | /**
109 | * Add an item without ID to the index. Internally, an incremental
110 | * identifier (starting from 1) will be given to this item.
111 | *
112 | * @param item - float array with the length expected by the index (dimension).
113 | */
114 | public void addItem(float[] item) {
115 | addItem(item, NO_ID);
116 | }
117 |
118 | /**
119 | * Add an item with ID to the index. It won't apply any extra normalization
120 | * unless it is required by the Vector Space (e.g., COSINE).
121 | *
122 | * @param item - float array with the length expected by the index (dimension);
123 | * @param id - an identifier used by the native library.
124 | */
125 | public void addItem(float[] item, int id) {
126 | checkResultCode(hnswlib.addItemToIndex(item, false, id, reference));
127 | }
128 |
129 | /**
130 | * Add an item with ID to the index. It won't apply any extra normalization
131 | * unless it is required by the Vector Space (e.g., COSINE).
132 | * Save id to internal collection if saveId = true
133 | * @param item item
134 | * @param id id
135 | * @param saveId true to save id to internal collection
136 | */
137 | public void addItem(float[] item, int id, boolean saveId) {
138 | addItem(item, id);
139 | if (saveId) {
140 | ids.add(id);
141 | }
142 | }
143 |
144 | public void setIds(IntSet ids) {
145 | this.ids = ids;
146 | }
147 |
148 | /**
149 | * Get ids for items in this index
150 | * @return set of ids
151 | */
152 | public IntSet getIds() {
153 | return ids;
154 | }
155 |
156 | /**
157 | * Add a normalized item without ID to the index. Internally, an incremental
158 | * ID (starting from 0) will be given to this item.
159 | *
160 | * @param item - float array with the length expected by the index (dimension).
161 | */
162 | public void addNormalizedItem(float[] item) {
163 | addNormalizedItem(item, NO_ID);
164 | }
165 |
166 | /**
167 | * Add a normalized item with ID to the index.
168 | *
169 | * @param item - float array with the length expected by the index (dimension);
170 | * @param id - an identifier used by the native library.
171 | */
172 | public void addNormalizedItem(float[] item, int id) {
173 | checkResultCode(hnswlib.addItemToIndex(item, true, id, reference));
174 | }
175 |
176 | /**
177 | * Add a normalized item with ID to the index.
178 | *
179 | * @param item - float array with the length expected by the index (dimension);
180 | * @param id - an identifier used by the native library.
181 | */
182 | public void addNormalizedItem(float[] item, int id, boolean saveId) {
183 | addNormalizedItem(item, id);
184 | if (saveId) {
185 | ids.add(id);
186 | }
187 | }
188 |
189 | /**
190 | * Return the number of elements already inserted in
191 | * the index.
192 | *
193 | * @return elements count.
194 | */
195 | public int getLength(){
196 | return hnswlib.getIndexLength(reference);
197 | }
198 |
199 | /**
200 | * Performs a knn query in the index instance. In case the vector space requires
201 | * the input to be normalized, it will normalize at the native level.
202 | *
203 | * @param input - float array;
204 | * @param k - number of results expected.
205 | *
206 | * @return a query tuple instance that contain the indices and coefficients.
207 | */
208 | public QueryTuple knnQuery(float[] input, int k) {
209 | QueryTuple queryTuple = new QueryTuple(k);
210 | checkResultCode(hnswlib.knnQuery(reference, input, false, k, queryTuple.ids, queryTuple.coefficients));
211 | return queryTuple;
212 | }
213 |
214 | /**
215 | * Performs a knn query in the index instance using an normalized input.
216 | * It will not normalize the vector again.
217 | *
218 | * @param input - a normalized float array;
219 | * @param k - number of results expected.
220 | *
221 | * @return a query tuple instance that contain the indices and coefficients.
222 | */
223 | public QueryTuple knnNormalizedQuery(float[] input, int k) {
224 | QueryTuple queryTuple = new QueryTuple(k);
225 | checkResultCode(hnswlib.knnQuery(reference, input, true, k, queryTuple.ids, queryTuple.coefficients));
226 | return queryTuple;
227 | }
228 |
229 | /**
230 | * Stores the content of the index into a file.
231 | * This method relies on the native implementation.
232 | *
233 | * @param path - destination path.
234 | */
235 | public void save(Path path) {
236 | checkResultCode(hnswlib.saveIndexToPath(reference, path.toAbsolutePath().toString()));
237 | }
238 |
239 | /**
240 | * This method loads the content stored in a file path onto the index.
241 | *
242 | * Note: if the index was previously initialized, the old
243 | * content will be erased.
244 | *
245 | * @param path - path to the index file;
246 | * @param maxNumberOfElements - max number of elements in the index.
247 | */
248 | public void load(Path path, int maxNumberOfElements) {
249 | checkResultCode(hnswlib.loadIndexFromPath(reference, maxNumberOfElements, path.toAbsolutePath().toString()));
250 | }
251 |
252 | /**
253 | * Free the memory allocated for this index in the native context.
254 | *
255 | * NOTE: Once the index is cleared, it cannot be initialized or used again.
256 | */
257 | public void clear() {
258 | checkResultCode(hnswlib.clearIndex(reference));
259 | cleared = true;
260 | }
261 |
262 | /**
263 | * Cleanup the area allocated by the index in the native side.
264 | *
265 | * @throws Throwable when anything weird happened. :)
266 | */
267 | @Override
268 | protected void finalize() throws Throwable {
269 | if (!cleared && !referenceReused) {
270 | this.clear();
271 | }
272 | super.finalize();
273 | }
274 |
275 | /**
276 | * This method checks the result code coming from the
277 | * native execution is correct otherwise throws an exception.
278 | *
279 | * @throws UnexpectedNativeException when something went out of control in the native side.
280 | */
281 | private void checkResultCode(int resultCode) {
282 | switch (resultCode) {
283 | case RESULT_SUCCESSFUL:
284 | break;
285 | case RESULT_QUERY_NO_RESULTS:
286 | throw new QueryCannotReturnResultsException();
287 | case RESULT_ITEM_CANNOT_BE_INSERTED_INTO_THE_VECTOR_SPACE:
288 | throw new ItemCannotBeInsertedIntoTheVectorSpaceException();
289 | case RESULT_ONCE_INDEX_IS_CLEARED_IT_CANNOT_BE_REUSED:
290 | throw new OnceIndexIsClearedItCannotBeReusedException();
291 | case RESULT_INDEX_NOT_INITIALIZED:
292 | throw new IndexNotInitializedException();
293 | default:
294 | throw new UnexpectedNativeException();
295 | }
296 | }
297 |
298 | /**
299 | * Checks whether there is an item with the specified identifier in the index.
300 | *
301 | * @param id - identifier.
302 | * @return true or false.
303 | */
304 | public boolean hasId(int id) {
305 | return hnswlib.hasId(reference, id) == RESULT_SUCCESSFUL;
306 | }
307 |
308 | /**
309 | * Gets the data from a specific identifier in the index.
310 | *
311 | * @param id - identifier.
312 | *
313 | * @return an optional containing or not the
314 | */
315 | public Optional getData(int id) {
316 | float[] vector = new float[dimension];
317 | int success = hnswlib.getData(reference, id, vector, dimension);
318 | if (success == RESULT_SUCCESSFUL) {
319 | return Optional.of(vector);
320 | }
321 | return Optional.empty();
322 | }
323 |
324 | /**
325 | * Computer similarity on the native side taking advantage of
326 | * SSE, AVX, SIMD instructions, when available.
327 | *
328 | * @param vector1 array with correct dimension;
329 | * @param vector2 array with correct dimension.
330 | *
331 | * @return the similarity score.
332 | */
333 | public float computeSimilarity(float[] vector1, float[] vector2) {
334 | checkIndexIsInitialized();
335 | return hnswlib.computeSimilarity(reference, vector1, vector2);
336 | }
337 |
338 | /**
339 | * Retrieves the current M value.
340 | *
341 | * @return the M value.
342 | */
343 | public int getM(){
344 | checkIndexIsInitialized();
345 | return hnswlib.getM(reference);
346 | }
347 |
348 | /**
349 | * Retrieves the current Ef value.
350 | *
351 | * @return the EF value.
352 | */
353 | public int getEf(){
354 | checkIndexIsInitialized();
355 | return hnswlib.getEf(reference);
356 | }
357 |
358 | /**
359 | * Retrieves the current ef construction.
360 | *
361 | * @return the ef construction value.
362 | */
363 | public int getEfConstruction(){
364 | checkIndexIsInitialized();
365 | return hnswlib.getEfConstruction(reference);
366 | }
367 |
368 | /**
369 | * Marks an ID as deleted.
370 | *
371 | * @param id identifier.
372 | */
373 | public void markDeleted(int id){
374 | checkResultCode(hnswlib.markDeleted(reference, id));
375 | if (ids.contains(id)) {
376 | ids.remove(id);
377 | }
378 | }
379 |
380 | private void checkIndexIsInitialized() {
381 | if (!initialized) {
382 | throw new IndexNotInitializedException();
383 | }
384 | }
385 |
386 | /**
387 | * Util function that normalizes an array.
388 | *
389 | * @param array input.
390 | */
391 | public static strictfp void normalize(float [] array){
392 | int n = array.length;
393 | float norm = 0;
394 | for (float v : array) {
395 | norm += v * v;
396 | }
397 | norm = (float) (1.0f / (Math.sqrt(norm) + 1e-30f));
398 | for (int i = 0; i < n; i++) {
399 | array[i] = array[i] * norm;
400 | }
401 | }
402 |
403 | /**
404 | * This method returns a ConcurrentIndex instance which contains
405 | * the same items present in a Index object. The indexes will share
406 | * the same native pointer, so there will be no memory duplication.
407 | *
408 | * It is important to say that the Index class allows adding items is in
409 | * parallel, so the building time can be much slower. On the other hand,
410 | * ConcurrentIndex offers thread-safe methods for adding and querying, which
411 | * can be interesting for multi-threaded environment with online
412 | * updates. Via this method, you can get the best of the two worlds.
413 | *
414 | * @param index with the items added
415 | * @return a thread-safe index
416 | */
417 | public static Index synchronizedIndex(Index index) {
418 | Index concurrentIndex = new ConcurrentIndex(index.spaceName, index.dimension);
419 | concurrentIndex.reference = index.reference;
420 | concurrentIndex.cleared = index.cleared;
421 | concurrentIndex.initialized = index.initialized;
422 | concurrentIndex.setIds(index.getIds());
423 | index.referenceReused = true;
424 | return concurrentIndex;
425 | }
426 | }
427 |
--------------------------------------------------------------------------------
/hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/AbstractIndexTest.java:
--------------------------------------------------------------------------------
1 | package com.stepstone.search.hnswlib.jna;
2 |
3 | import com.stepstone.search.hnswlib.jna.exception.IndexAlreadyInitializedException;
4 | import com.stepstone.search.hnswlib.jna.exception.IndexNotInitializedException;
5 | import com.stepstone.search.hnswlib.jna.exception.ItemCannotBeInsertedIntoTheVectorSpaceException;
6 | import com.stepstone.search.hnswlib.jna.exception.OnceIndexIsClearedItCannotBeReusedException;
7 | import com.stepstone.search.hnswlib.jna.exception.QueryCannotReturnResultsException;
8 | import com.stepstone.search.hnswlib.jna.exception.UnexpectedNativeException;
9 | import org.junit.Test;
10 |
11 | import java.io.File;
12 | import java.io.IOException;
13 | import java.nio.file.Path;
14 | import java.nio.file.Paths;
15 | import java.util.Optional;
16 | import java.util.concurrent.ExecutorService;
17 | import java.util.concurrent.Executors;
18 | import java.util.concurrent.TimeUnit;
19 |
20 | import static org.junit.Assert.assertArrayEquals;
21 | import static org.junit.Assert.assertEquals;
22 | import static org.junit.Assert.assertFalse;
23 | import static org.junit.Assert.assertNotNull;
24 | import static org.junit.Assert.assertNull;
25 | import static org.junit.Assert.assertTrue;
26 |
27 | public abstract class AbstractIndexTest {
28 |
29 | protected abstract Index createIndexInstance(SpaceName spaceName, int dimensions);
30 |
31 | @Test
32 | public void testSingleIndexInstantiation() throws UnexpectedNativeException {
33 | Index i1 = createIndexInstance(SpaceName.IP, 30);
34 | assertNotNull(i1);
35 | i1.clear();
36 | }
37 |
38 | @Test
39 | public void testMultipleIndexInstantiation() throws UnexpectedNativeException {
40 | Index i1 = createIndexInstance(SpaceName.IP, 30);
41 | assertNotNull(i1);
42 | Index i2 = createIndexInstance(SpaceName.COSINE, 30);
43 | assertNotNull(i2);
44 | Index i3 = createIndexInstance(SpaceName.L2, 30);
45 | assertNotNull(i3);
46 | i1.clear();
47 | i2.clear();
48 | i3.clear();
49 | }
50 |
51 | @Test
52 | public void testIndexInitialization() throws UnexpectedNativeException {
53 | Index i1 = createIndexInstance(SpaceName.COSINE, 50);
54 | i1.initialize(500_000, 16, 200, 100);
55 | assertEquals(0, i1.getLength());
56 | i1.clear();
57 | }
58 |
59 | @Test
60 | public void testIndexInitialization2() throws UnexpectedNativeException {
61 | Index i1 = createIndexInstance(SpaceName.COSINE, 50);
62 | i1.initialize();
63 | assertEquals(0, i1.getLength());
64 | i1.clear();
65 | }
66 |
67 | @Test(expected = IndexAlreadyInitializedException.class)
68 | public void testIndexMultipleInitialization() throws UnexpectedNativeException {
69 | Index i1 = createIndexInstance(SpaceName.COSINE, 50);
70 | i1.initialize(500_000, 16, 200, 100);
71 | i1.initialize();
72 | }
73 |
74 | @Test
75 | public void testIndexAddItem() throws UnexpectedNativeException {
76 | Index i1 = createIndexInstance(SpaceName.COSINE, 3);
77 | i1.initialize(1);
78 | i1.addItem(new float[] { 1.3f, 1.2f, 1.5f }, 3);
79 | assertEquals(1, i1.getLength());
80 | i1.clear();
81 | }
82 |
83 | @Test
84 | public void testIndexAddItemIndependence() throws UnexpectedNativeException {
85 | testIndexAddItem();
86 | Index i2 = createIndexInstance(SpaceName.IP, 4);
87 | i2.initialize(3);
88 | assertEquals(0, i2.getLength());
89 | i2.clear();
90 | }
91 |
92 | @Test
93 | public void testIndexSaveAndLoad() throws UnexpectedNativeException, IOException {
94 | File tempFile = File.createTempFile("index", "sm");
95 | Path tempFilePath = Paths.get(tempFile.getAbsolutePath());
96 |
97 | Index i1 = createIndexInstance(SpaceName.COSINE, 3);
98 | i1.initialize(1);
99 | i1.addItem(new float[] { 1.3f, 1.2f, 1.5f }, 3);
100 | i1.save(tempFilePath);
101 | i1.clear();
102 |
103 | Index i2 = createIndexInstance(SpaceName.COSINE, 3);
104 | assertEquals(0, i2.getLength());
105 | i2.load(tempFilePath,1);
106 | assertEquals(1, i2.getLength());
107 | i2.clear();
108 |
109 | assertTrue(tempFile.delete());
110 | }
111 |
112 | @Test
113 | public void testParallelAddItemsInMultipleIndexes() throws InterruptedException, UnexpectedNativeException {
114 | int cpus = Runtime.getRuntime().availableProcessors();
115 | ExecutorService executorService = Executors.newFixedThreadPool(cpus);
116 |
117 | Index i1 = createIndexInstance(SpaceName.L2, 50);
118 | i1.initialize(1_050);
119 |
120 | Index i2 = createIndexInstance(SpaceName.COSINE, 50);
121 | i2.initialize(1_050);
122 |
123 | Runnable addItemIndex1 = () -> {
124 | try {
125 | i1.addItem(HnswlibTestUtils.getRandomFloatArray(50));
126 | } catch (UnexpectedNativeException e) {
127 | e.printStackTrace();
128 | }
129 | };
130 | Runnable addItemIndex2 = () -> {
131 | try {
132 | i2.addItem(HnswlibTestUtils.getRandomFloatArray(50));
133 | } catch (UnexpectedNativeException e) {
134 | e.printStackTrace();
135 | }
136 | };
137 |
138 | for(int i = 0; i < 1_000; i++) {
139 | executorService.submit(addItemIndex1);
140 | executorService.submit(addItemIndex2);
141 | }
142 |
143 | executorService.shutdown();
144 | executorService.awaitTermination(5, TimeUnit.MINUTES);
145 |
146 | assertEquals(1_000, i1.getLength());
147 | assertEquals(1_000, i2.getLength());
148 |
149 | i1.clear(); i2.clear();
150 | }
151 |
152 | @Test
153 | public void testConcurrentInsertQuery() throws InterruptedException, UnexpectedNativeException {
154 | ExecutorService executorService = Executors.newFixedThreadPool(50);
155 |
156 | Index i1 = createIndexInstance(SpaceName.L2, 50);
157 | i1.initialize(1_050);
158 |
159 | float[] randomFloatArray = HnswlibTestUtils.getRandomFloatArray(50);
160 |
161 | Runnable addItemIndex1 = () -> {
162 | try {
163 | i1.addItem(randomFloatArray);
164 | } catch (UnexpectedNativeException e) {
165 | e.printStackTrace();
166 | }
167 | };
168 |
169 | Runnable queryItemIndex1 = () -> {
170 | QueryTuple queryTuple;
171 | try {
172 | queryTuple = i1.knnQuery(randomFloatArray, 1);
173 | assertEquals(50, queryTuple.getIds().length);
174 | assertEquals(50, queryTuple.getCoefficients().length);
175 | } catch (UnexpectedNativeException e) {
176 | e.printStackTrace();
177 | }
178 | };
179 |
180 | for(int i = 0; i < 1_000; i++) {
181 | executorService.submit(addItemIndex1);
182 | executorService.submit(queryItemIndex1);
183 | }
184 |
185 | executorService.shutdown();
186 | executorService.awaitTermination(5, TimeUnit.MINUTES);
187 |
188 | assertEquals(1_000, i1.getLength());
189 | i1.clear();
190 | }
191 |
192 | @Test(expected = QueryCannotReturnResultsException.class)
193 | public void testQueryEmptyException() throws UnexpectedNativeException {
194 | Index idx = createIndexInstance(SpaceName.COSINE, 3);
195 | idx.initialize(300);
196 | QueryTuple queryTuple = idx.knnQuery(new float[] {1.3f, 1.4f, 1.5f}, 3);
197 | assertNull(queryTuple);
198 | }
199 |
200 | @Test
201 | public void testOverwritingAnItemInTheModel() throws UnexpectedNativeException {
202 | Index index = createIndexInstance(SpaceName.COSINE, 4);
203 | index.initialize(5);
204 |
205 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 1.0f}, 1);
206 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 0.95f}, 2);
207 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 0.9f}, 3);
208 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 0.85f}, 4);
209 |
210 | QueryTuple queryTuple = index.knnQuery(new float[] {1.0f, 1.0f, 1.0f, 1.0f}, 3);
211 | assertEquals(1, queryTuple.ids[0]);
212 | assertEquals(2, queryTuple.ids[1]);
213 | assertEquals(3, queryTuple.ids[2]);
214 |
215 | index.addItem(new float[] { 0.0f, 0.0f, 0.0f, 0.0f}, 2);
216 | queryTuple = index.knnQuery(new float[] {1.0f, 1.0f, 1.0f, 1.0f}, 3);
217 | assertEquals(1, queryTuple.ids[0]);
218 | assertEquals(3, queryTuple.ids[1]);
219 | assertEquals(4, queryTuple.ids[2]);
220 |
221 | index.clear();
222 | }
223 |
224 | @Test(expected = ItemCannotBeInsertedIntoTheVectorSpaceException.class)
225 | public void testIncludingMoreItemsThanPossible() throws UnexpectedNativeException {
226 | Index index = createIndexInstance(SpaceName.L2, 4);
227 | index.initialize(2);
228 |
229 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 1.0f}, 1);
230 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 0.95f}, 2);
231 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 0.9f}, 3);
232 | }
233 |
234 | @Test // checks which index calls rewrite the caller's float[] with normalized values
235 | public void testNativeArrayNormalization() throws UnexpectedNativeException {
236 | final Index cosineIndex = createIndexInstance(SpaceName.COSINE, 7);
237 | cosineIndex.initialize(20);
238 |
239 | final float[] first = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
240 | cosineIndex.addItem(first); // the next assertion expects addItem to have normalized the array in place
241 | assertArrayEquals(new float[] {0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f}, first, 0.000001f);
242 |
243 | final float[] second = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
244 | cosineIndex.addNormalizedItem(first); // hands over the already-normalized array; second is not passed to the index here
245 | assertArrayEquals(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, second, 0.000001f);
246 |
247 | final QueryTuple viaKnnQuery = cosineIndex.knnQuery(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, 1);
248 | assertEquals(-2.3841858E-7f, viaKnnQuery.getCoefficients()[0], 0.00001);
249 |
250 | final QueryTuple viaNormalizedQuery = cosineIndex.knnNormalizedQuery(second, 1); // second is still un-normalized, hence the different coefficient
251 | assertEquals(-1.645751476f, viaNormalizedQuery.getCoefficients()[0], 0.00001);
252 | }
253 |
254 | @Test // Index.normalize is expected to rewrite the given array in place
255 | public void testHostNormalization() {
256 | final float[] ones = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
257 | Index.normalize(ones);
258 | assertArrayEquals(new float[] {0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f}, ones, 0.000001f);
259 | }
260 |
261 | @Test // a COSINE index and an IP index fed the same pre-normalized vectors must return the same ids and coefficients
262 | public void testIndexCosineEqualsToIPWhenNormalized() throws UnexpectedNativeException {
263 | float[] i1 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
264 | Index.normalize(i1);
265 | float[] i2 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f};
266 | Index.normalize(i2);
267 | float[] i3 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f};
268 | Index.normalize(i3);
269 |
270 | Index indexCosine = createIndexInstance(SpaceName.COSINE, 7);
271 | indexCosine.initialize(3);
272 | indexCosine.addNormalizedItem(i1, 1_111_111);
273 | indexCosine.addNormalizedItem(i2, 1_222_222);
274 | indexCosine.addNormalizedItem(i3, 1_333_333);
275 |
276 | Index indexIP = createIndexInstance(SpaceName.IP, 7);
277 | indexIP.initialize(3);
278 | indexIP.addNormalizedItem(i1, 1_111_111);
279 | indexIP.addNormalizedItem(i2, 1_222_222);
280 | indexIP.addNormalizedItem(i3, 1_333_333);
281 |
282 | float[] input = new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
283 | Index.normalize(input);
284 |
285 | QueryTuple cosineQT = indexCosine.knnNormalizedQuery(input, 3);
286 | QueryTuple ipQT = indexIP.knnNormalizedQuery(input, 3); // must come from the IP index; querying indexCosine twice would compare a result with itself
287 |
288 | assertArrayEquals(cosineQT.getCoefficients(), ipQT.getCoefficients(), 0.000001f);
289 | assertArrayEquals(cosineQT.getIds(), ipQT.getIds());
290 |
291 | indexIP.clear();
292 | indexCosine.clear();
293 | }
294 |
295 | @Test // IP space: the expected coefficients below equal 1 minus the dot product of the query with each item
296 | public void testSimpleQueryOf5ElementsAndDimension7IP() throws UnexpectedNativeException {
297 | final Index ipIndex = createIndexInstance(SpaceName.IP, 7);
298 | ipIndex.initialize(7);
299 |
300 | ipIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, 5);
301 | ipIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f}, 6);
302 | ipIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f}, 7);
303 | ipIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.85f}, 8);
304 | ipIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.8f}, 9);
305 |
306 | final float[] query = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
307 | final QueryTuple nearest = ipIndex.knnQuery(query, 4);
308 |
309 | assertArrayEquals(new int[] {5, 6, 7, 8}, nearest.getIds());
310 | assertArrayEquals(new float[] {-6.0f, -5.95f, -5.9f, -5.85f}, nearest.getCoefficients(), 0.000001f);
311 | ipIndex.clear();
312 | }
313 |
314 | @Test // COSINE space: four nearest of five near-identical vectors, closest first
315 | public void testSimpleQueryOf5ElementsAndDimension7Cosine() throws UnexpectedNativeException {
316 | final Index cosineIndex = createIndexInstance(SpaceName.COSINE, 7);
317 | cosineIndex.initialize(7);
318 |
319 | cosineIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, 14);
320 | cosineIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f}, 13);
321 | cosineIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f}, 12);
322 | cosineIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.85f}, 11);
323 | cosineIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.8f}, 10);
324 |
325 | final float[] query = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
326 | final QueryTuple nearest = cosineIndex.knnQuery(query, 4);
327 |
328 | assertArrayEquals(new int[] {14, 13, 12, 11}, nearest.getIds());
329 | assertArrayEquals(new float[] {-2.3841858E-7f, 1.552105E-4f, 6.2948465E-4f, 0.001435399f}, nearest.getCoefficients(), 0.000001f);
330 | cosineIndex.clear();
331 | }
332 |
333 | @Test // L2 space: items are inserted out of distance order, results must come back closest first
334 | public void testSimpleQueryOf5ElementsAndDimension7L2() throws UnexpectedNativeException {
335 | final Index l2Index = createIndexInstance(SpaceName.L2, 7);
336 | l2Index.initialize(7);
337 |
338 | l2Index.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f}, 48);
339 | l2Index.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.85f}, 10);
340 | l2Index.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f}, 35);
341 | l2Index.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.8f}, 1);
342 | l2Index.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, 33);
343 |
344 | final float[] query = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
345 | final QueryTuple nearest = l2Index.knnQuery(query, 4);
346 |
347 | assertArrayEquals(new int[] {33, 35, 48, 10}, nearest.getIds());
348 | assertArrayEquals(new float[] {0.0f, 0.002500001f, 0.010000004f, 0.022499993f}, nearest.getCoefficients(), 0.000001f);
349 | l2Index.clear();
350 | }
351 |
352 | @Test(expected = OnceIndexIsClearedItCannotBeReusedException.class) // clearing the same index twice is expected to throw
353 | public void testDoubleClear() throws UnexpectedNativeException {
354 | final Index ipIndex = createIndexInstance(SpaceName.IP, 30);
355 | ipIndex.initialize(3);
356 | ipIndex.clear();
357 | ipIndex.clear();
358 | }
359 |
360 | @Test(expected = OnceIndexIsClearedItCannotBeReusedException.class) // clear() followed by initialize() is expected to throw
361 | public void testUsageAfterClear1() throws UnexpectedNativeException {
362 | final Index ipIndex = createIndexInstance(SpaceName.IP, 30);
363 | ipIndex.clear();
364 | ipIndex.initialize(30);
365 | }
366 |
367 | @Test(expected = OnceIndexIsClearedItCannotBeReusedException.class) // addItem on a cleared index is expected to throw
368 | public void testUsageAfterClear2() throws UnexpectedNativeException {
369 | final Index ipIndex = createIndexInstance(SpaceName.IP, 30);
370 | ipIndex.initialize(30);
371 | ipIndex.clear();
372 | ipIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f}, 48);
373 | }
374 |
375 | @Test(expected = OnceIndexIsClearedItCannotBeReusedException.class) // knnQuery on a cleared index is expected to throw
376 | public void testUsageAfterClear3() throws UnexpectedNativeException {
377 | final Index ipIndex = createIndexInstance(SpaceName.IP, 30);
378 | ipIndex.initialize(30);
379 | ipIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f}, 48);
380 | ipIndex.clear();
381 | final float[] query = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
382 | ipIndex.knnQuery(query, 4);
383 | }
384 |
385 | @Test // a cleared index whose reference is dropped must not disturb a fresh index when garbage collection runs
386 | public void testTryingDoubleClearDueToGCWhenReferenceIsLost() throws UnexpectedNativeException {
387 | Index current = createIndexInstance(SpaceName.IP, 30);
388 | current.initialize(30);
389 | current.clear();
390 | current = createIndexInstance(SpaceName.IP, 30); // the cleared instance is no longer referenced from here on
391 | // Ask for garbage collection ten times before touching the new instance.
392 | for (int remaining = 10; remaining > 0; remaining--) {
393 | System.gc();
394 | }
395 | current.initialize(30);
396 | assertNotNull(current);
397 | }
398 |
399 | @Test // exercises hasId/getData for stored ids, unknown ids, and after clear()
400 | public void testGetData() {
401 | Index index = createIndexInstance(SpaceName.COSINE, 3);
402 | index.initialize();
403 | float[] vector = {1F, 2F, 3F};
404 | index.addItem(vector); // no id supplied; the assertions below expect the item under id 0
405 | assertTrue(index.hasId(0));
406 | Optional<float[]> data = index.getData(0); // typed Optional so data.get() is a float[] for assertArrayEquals
407 | assertTrue(data.isPresent());
408 | assertArrayEquals(vector, data.get(), 0.0f); // zero tolerance: stored data must equal the caller's array as addItem left it
409 | assertFalse(index.hasId(1));
410 | assertFalse(index.getData(1).isPresent());
411 |
412 | float[] vector2 = {1F, 2F, 3F};
413 | index.addItem(vector2, 1230);
414 | assertTrue(index.hasId(1230));
415 | assertFalse(index.hasId(1231));
416 |
417 | index.clear();
418 | assertFalse(index.hasId(1230)); // after clear() even a previously stored id is reported absent
419 | assertFalse(index.hasId(1231));
420 | }
421 |
422 | @Test // hasId must answer false both on a cleared index and on one that was never initialized
423 | public void testGetDataWhenIndexCleared() {
424 | final Index cleared = createIndexInstance(SpaceName.COSINE, 3);
425 | cleared.initialize();
426 | cleared.clear();
427 | assertFalse(cleared.hasId(1202));
428 | final Index neverInitialized = createIndexInstance(SpaceName.COSINE, 1);
429 | assertFalse(neverInitialized.hasId(123));
430 | }
431 |
432 | @Test(expected = IndexNotInitializedException.class) // addItem before initialize() is expected to throw
433 | public void testUseAddItemIndexWithoutInitialize() {
434 | final Index uninitialized = createIndexInstance(SpaceName.COSINE, 1);
435 | uninitialized.addItem(new float[1]);
436 | }
437 |
438 | @Test(expected = IndexNotInitializedException.class) // addNormalizedItem before initialize() is expected to throw
439 | public void testUseAddNormalizedItemIndexWithoutInitialize() {
440 | final Index uninitialized = createIndexInstance(SpaceName.COSINE, 1);
441 | uninitialized.addNormalizedItem(new float[1]);
442 | }
443 |
444 | @Test(expected = IndexNotInitializedException.class) // knnQuery before initialize() is expected to throw
445 | public void testUseKnnQueryIndexWithoutInitialize() {
446 | final Index uninitialized = createIndexInstance(SpaceName.COSINE, 1);
447 | uninitialized.knnQuery(new float[1], 1);
448 | }
449 |
450 | @Test(expected = IndexNotInitializedException.class) // knnNormalizedQuery before initialize() is expected to throw
451 | public void testUseKnnNormalizedQueryQueryIndexWithoutInitialize() {
452 | final Index uninitialized = createIndexInstance(SpaceName.COSINE, 1);
453 | uninitialized.knnNormalizedQuery(new float[1], 1);
454 | }
455 |
456 | @Test(expected = IndexNotInitializedException.class) // getM before initialize() is expected to throw
457 | public void testGetMWithoutInitializeIndex() {
458 | final Index uninitialized = createIndexInstance(SpaceName.COSINE, 1);
459 | uninitialized.getM();
460 | }
461 |
462 | @Test(expected = IndexNotInitializedException.class) // getEf before initialize() is expected to throw
463 | public void testGetEfWithoutInitializeIndex() {
464 | final Index uninitialized = createIndexInstance(SpaceName.COSINE, 1);
465 | uninitialized.getEf();
466 | }
467 |
468 | @Test(expected = IndexNotInitializedException.class) // getEfConstruction before initialize() is expected to throw
469 | public void testGetEfConstructionWithoutInitializeIndex() {
470 | final Index uninitialized = createIndexInstance(SpaceName.COSINE, 1);
471 | uninitialized.getEfConstruction();
472 | }
473 |
474 | @Test // ids passed to markDeleted must disappear from later query results while the remaining ranking is kept
475 | public void testMarkAsDeleted() {
476 | final Index cosineIndex = createIndexInstance(SpaceName.COSINE, 7);
477 | cosineIndex.initialize(7);
478 |
479 | cosineIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, 14);
480 | cosineIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f}, 13);
481 | cosineIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f}, 12);
482 | cosineIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.85f}, 11);
483 | cosineIndex.addItem(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.8f}, 10);
484 |
485 | final float[] query = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; // the same array instance is used for all three queries
486 | final QueryTuple allPresent = cosineIndex.knnQuery(query, 4);
487 |
488 | assertArrayEquals(new int[] {14, 13, 12, 11}, allPresent.getIds());
489 | assertArrayEquals(new float[] {-2.3841858E-7f, 1.552105E-4f, 6.2948465E-4f, 0.001435399f}, allPresent.getCoefficients(), 0.000001f);
490 |
491 | cosineIndex.markDeleted(13);
492 | final QueryTuple without13 = cosineIndex.knnQuery(query, 4);
493 | assertArrayEquals(new int[] {14, 12, 11, 10}, without13.getIds());
494 | assertArrayEquals(new float[] {-2.3841858E-7f, 6.2948465E-4f, 0.001435399f, 0.0025851727f}, without13.getCoefficients(), 0.000001f);
495 |
496 | cosineIndex.markDeleted(12);
497 | final QueryTuple without13And12 = cosineIndex.knnQuery(query, 3);
498 | assertArrayEquals(new int[] {14, 11, 10}, without13And12.getIds());
499 | assertArrayEquals(new float[] {-2.3841858E-7f, 0.001435399f, 0.0025851727f}, without13And12.getCoefficients(), 0.000001f);
500 |
501 | cosineIndex.clear();
502 | }
503 |
504 | }
505 |
--------------------------------------------------------------------------------
/hnswlib/hnswalg.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "visited_list_pool.h"
4 | #include "hnswlib.h"
5 | #include
6 | #include
7 | #include
8 | #include
9 | #include
10 | #include
11 |
12 | namespace hnswlib {
13 | typedef unsigned int tableint;
14 | typedef unsigned int linklistsizeint;
15 |
16 | template
17 | class HierarchicalNSW : public AlgorithmInterface {
18 | public:
19 | static const tableint max_update_element_locks = 65536;
20 | HierarchicalNSW(SpaceInterface *s) { // empty body: s is not used and no member is assigned here
21 | // NOTE(review): the destructor frees data_level0_memory_ and linkLists_, which this constructor leaves unset - confirm callers always load/initialize before destruction
22 | }
23 |
24 | HierarchicalNSW(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { // builds the index from the file at location; the nmslib flag is not used in this body
25 | loadIndex(location, s, max_elements);
26 | }
27 |
28 | HierarchicalNSW(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : // creates an empty index with room for max_elements elements
29 | link_list_locks_(max_elements), element_levels_(max_elements), link_list_update_locks_(max_update_element_locks) {
30 | max_elements_ = max_elements;
31 |
32 | has_deletions_=false;
33 | data_size_ = s->get_data_size();
34 | fstdistfunc_ = s->get_dist_func();
35 | dist_func_param_ = s->get_dist_func_param();
36 | M_ = M;
37 | maxM_ = M_; // link capacity per element on levels above 0
38 | maxM0_ = M_ * 2; // link capacity per element on level 0
39 | ef_construction_ = std::max(ef_construction,M_); // never smaller than M_
40 | ef_ = 10;
41 |
42 | level_generator_.seed(random_seed);
43 | update_probability_generator_.seed(random_seed + 1);
44 |
45 | size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); // one count field followed by maxM0_ neighbour ids
46 | size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); // level-0 record = links, then vector data, then label
47 | offsetData_ = size_links_level0_;
48 | label_offset_ = size_links_level0_ + data_size_;
49 | offsetLevel0_ = 0;
50 |
51 | data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_);
52 | if (data_level0_memory_ == nullptr)
53 | throw std::runtime_error("Not enough memory");
54 |
55 | cur_element_count = 0;
56 |
57 | visited_list_pool_ = new VisitedListPool(1, max_elements);
58 |
59 |
60 |
61 | //initializations for special treatment of the first node
62 | enterpoint_node_ = -1; // tableint is unsigned, so this stores the largest unsigned value as a "no entry point" marker
63 | maxlevel_ = -1;
64 |
65 | linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); // one pointer per element for its upper-level link storage
66 | if (linkLists_ == nullptr)
67 | throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
68 | size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); // size of one upper-level link list
69 | mult_ = 1 / log(1.0 * M_);
70 | revSize_ = 1.0 / mult_;
71 | }
72 |
73 | struct CompareByFirst { // comparator that orders pairs by .first only; .second never takes part
74 | constexpr bool operator()(std::pair const &a,
75 | std::pair const &b) const noexcept {
76 | return a.first < b.first;
77 | }
78 | };
79 |
80 | ~HierarchicalNSW() {
81 | // Release per-element upper-level link storage first; only elements with a level above 0 have any.
82 | for (tableint id = 0; id < cur_element_count; ++id) {
83 | if (element_levels_[id] <= 0) continue;
84 | free(linkLists_[id]);
85 | }
86 | free(linkLists_);
87 | free(data_level0_memory_);
88 | delete visited_list_pool_;
89 | }
90 |
91 | size_t max_elements_; // capacity the storage arrays are allocated for
92 | size_t cur_element_count; // number of element slots currently in use
93 | size_t size_data_per_element_; // bytes per level-0 record (links + data + label)
94 | size_t size_links_per_element_; // bytes per upper-level link list
95 |
96 | size_t M_;
97 | size_t maxM_; // link capacity above level 0
98 | size_t maxM0_; // link capacity on level 0
99 | size_t ef_construction_;
100 |
101 | double mult_, revSize_;
102 | int maxlevel_; // -1 until set elsewhere; the main constructor stores -1
103 |
104 |
105 | VisitedListPool *visited_list_pool_;
106 | std::mutex cur_element_count_guard_;
107 |
108 | std::vector link_list_locks_; // indexed by internal element id when link lists are read or modified
109 |
110 | // Locks to prevent race condition during update/insert of an element at same time.
111 | // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel.
112 | std::vector link_list_update_locks_;
113 | tableint enterpoint_node_;
114 |
115 |
116 | size_t size_links_level0_; // bytes of the level-0 link list inside each record
117 | size_t offsetData_, offsetLevel0_; // byte offsets of the vector data and of the level-0 links inside a record
118 |
119 |
120 | char *data_level0_memory_; // contiguous block of level-0 records
121 | char **linkLists_; // per-element pointer to upper-level link storage
122 | std::vector element_levels_; // per-element level; values above 0 mean linkLists_ holds storage for that element
123 |
124 | size_t data_size_;
125 |
126 | bool has_deletions_;
127 |
128 |
129 | size_t label_offset_; // byte offset of the external label inside a record
130 | DISTFUNC fstdistfunc_;
131 | void *dist_func_param_;
132 | std::unordered_map label_lookup_;
133 |
134 | std::default_random_engine level_generator_;
135 | std::default_random_engine update_probability_generator_;
136 |
137 | inline labeltype getExternalLabel(tableint internal_id) const { // reads the external label stored in the element's level-0 record
138 | labeltype label; // filled by the memcpy below
139 | memcpy(&label, data_level0_memory_ + label_offset_ + internal_id * size_data_per_element_, sizeof(labeltype));
140 | return label;
141 | }
142 |
143 | inline void setExternalLabel(tableint internal_id, labeltype label) const { // writes label into the element's level-0 record
144 | memcpy(data_level0_memory_ + label_offset_ + internal_id * size_data_per_element_, &label, sizeof(labeltype));
145 | }
146 |
147 | inline labeltype *getExternalLabeLp(tableint internal_id) const { // pointer to the label slot inside the element's level-0 record
148 | return (labeltype *) (data_level0_memory_ + label_offset_ + internal_id * size_data_per_element_);
149 | }
150 |
151 | inline char *getDataByInternalId(tableint internal_id) const { // pointer to the vector data inside the element's level-0 record
152 | return data_level0_memory_ + offsetData_ + internal_id * size_data_per_element_;
153 | }
154 |
155 | int getRandomLevel(double reverse_size) { // draws a level as the truncated value of -log(u) * reverse_size, u taken from level_generator_
156 | std::uniform_real_distribution distribution(0.0, 1.0);
157 | const double sample = distribution(level_generator_);
158 | return (int) (-log(sample) * reverse_size);
159 | }
160 |
161 |
162 | std::priority_queue, std::vector>, CompareByFirst> // result: up to ef_construction_ non-deleted elements closest to data_point found on the given layer
163 | searchBaseLayer(tableint ep_id, const void *data_point, int layer) { // ep_id: internal id where the search starts
164 | VisitedList *vl = visited_list_pool_->getFreeVisitedList();
165 | vl_type *visited_array = vl->mass;
166 | vl_type visited_array_tag = vl->curV; // an element counts as visited when its slot equals this tag
167 |
168 | std::priority_queue, std::vector>, CompareByFirst> top_candidates; // results; top() is the largest distance kept
169 | std::priority_queue, std::vector>, CompareByFirst> candidateSet; // frontier; distances are stored negated so top() is the closest
170 |
171 | dist_t lowerBound;
172 | if (!isMarkedDeleted(ep_id)) {
173 | dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
174 | top_candidates.emplace(dist, ep_id);
175 | lowerBound = dist;
176 | candidateSet.emplace(-dist, ep_id);
177 | } else {
178 | lowerBound = std::numeric_limits::max(); // deleted entry point: expand from it but keep it out of the results
179 | candidateSet.emplace(-lowerBound, ep_id);
180 | }
181 | visited_array[ep_id] = visited_array_tag;
182 |
183 | while (!candidateSet.empty()) {
184 | std::pair curr_el_pair = candidateSet.top();
185 | if ((-curr_el_pair.first) > lowerBound) { // closest unexpanded candidate is already farther than the current bound
186 | break;
187 | }
188 | candidateSet.pop();
189 |
190 | tableint curNodeNum = curr_el_pair.second;
191 |
192 | std::unique_lock lock(link_list_locks_[curNodeNum]); // held while this node's link list is read
193 |
194 | int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
195 | if (layer == 0) {
196 | data = (int*)get_linklist0(curNodeNum);
197 | } else {
198 | data = (int*)get_linklist(curNodeNum, layer);
199 | // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
200 | }
201 | size_t size = getListCount((linklistsizeint*)data);
202 | tableint *datal = (tableint *) (data + 1); // neighbour ids start right after the count field
203 | #ifdef USE_SSE
204 | _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
205 | _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
206 | _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
207 | _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
208 | #endif
209 |
210 | for (size_t j = 0; j < size; j++) {
211 | tableint candidate_id = *(datal + j);
212 | // if (candidate_id == 0) continue;
213 | #ifdef USE_SSE
214 | _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
215 | _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
216 | #endif
217 | if (visited_array[candidate_id] == visited_array_tag) continue;
218 | visited_array[candidate_id] = visited_array_tag;
219 | char *currObj1 = (getDataByInternalId(candidate_id));
220 |
221 | dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
222 | if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
223 | candidateSet.emplace(-dist1, candidate_id);
224 | #ifdef USE_SSE
225 | _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
226 | #endif
227 |
228 | if (!isMarkedDeleted(candidate_id)) // deleted elements are still expanded but never returned
229 | top_candidates.emplace(dist1, candidate_id);
230 |
231 | if (top_candidates.size() > ef_construction_)
232 | top_candidates.pop(); // drops the farthest kept result
233 |
234 | if (!top_candidates.empty())
235 | lowerBound = top_candidates.top().first;
236 | }
237 | }
238 | }
239 | visited_list_pool_->releaseVisitedList(vl);
240 |
241 | return top_candidates;
242 | }
243 |
244 | mutable std::atomic metric_distance_computations;
245 | mutable std::atomic metric_hops;
246 |
247 | template // NOTE(review): the template parameter list is missing in this listing; the body uses has_deletions and collect_metrics as compile-time flags - confirm against the original header
248 | std::priority_queue, std::vector>, CompareByFirst> // result: up to ef elements closest to data_point found on level 0
249 | searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
250 | VisitedList *vl = visited_list_pool_->getFreeVisitedList();
251 | vl_type *visited_array = vl->mass;
252 | vl_type visited_array_tag = vl->curV;
253 |
254 | std::priority_queue, std::vector>, CompareByFirst> top_candidates; // results; top() is the largest distance kept
255 | std::priority_queue, std::vector>, CompareByFirst> candidate_set; // frontier; distances are stored negated so top() is the closest
256 |
257 | dist_t lowerBound;
258 | if (!has_deletions || !isMarkedDeleted(ep_id)) {
259 | dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
260 | lowerBound = dist;
261 | top_candidates.emplace(dist, ep_id);
262 | candidate_set.emplace(-dist, ep_id);
263 | } else {
264 | lowerBound = std::numeric_limits::max(); // deleted entry point: expand from it but keep it out of the results
265 | candidate_set.emplace(-lowerBound, ep_id);
266 | }
267 |
268 | visited_array[ep_id] = visited_array_tag;
269 |
270 | while (!candidate_set.empty()) {
271 |
272 | std::pair current_node_pair = candidate_set.top();
273 |
274 | if ((-current_node_pair.first) > lowerBound) { // closest unexpanded candidate is already farther than the current bound
275 | break;
276 | }
277 | candidate_set.pop();
278 |
279 | tableint current_node_id = current_node_pair.second;
280 | int *data = (int *) get_linklist0(current_node_id); // this function always reads the level-0 list
281 | size_t size = getListCount((linklistsizeint*)data);
282 | // bool cur_node_deleted = isMarkedDeleted(current_node_id);
283 | if(collect_metrics){
284 | metric_hops++;
285 | metric_distance_computations+=size;
286 | }
287 |
288 | #ifdef USE_SSE
289 | _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
290 | _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
291 | _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
292 | _mm_prefetch((char *) (data + 2), _MM_HINT_T0);
293 | #endif
294 |
295 | for (size_t j = 1; j <= size; j++) { // data[0] is the count, so neighbour ids occupy data[1..size]
296 | int candidate_id = *(data + j);
297 | // if (candidate_id == 0) continue;
298 | #ifdef USE_SSE
299 | _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
300 | _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
301 | _MM_HINT_T0);////////////
302 | #endif
303 | if (!(visited_array[candidate_id] == visited_array_tag)) {
304 |
305 | visited_array[candidate_id] = visited_array_tag;
306 |
307 | char *currObj1 = (getDataByInternalId(candidate_id));
308 | dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);
309 |
310 | if (top_candidates.size() < ef || lowerBound > dist) {
311 | candidate_set.emplace(-dist, candidate_id);
312 | #ifdef USE_SSE
313 | _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
314 | offsetLevel0_,///////////
315 | _MM_HINT_T0);////////////////////////
316 | #endif
317 |
318 | if (!has_deletions || !isMarkedDeleted(candidate_id)) // deleted elements are still expanded but never returned
319 | top_candidates.emplace(dist, candidate_id);
320 |
321 | if (top_candidates.size() > ef)
322 | top_candidates.pop(); // drops the farthest kept result
323 |
324 | if (!top_candidates.empty())
325 | lowerBound = top_candidates.top().first;
326 | }
327 | }
328 | }
329 | }
330 |
331 | visited_list_pool_->releaseVisitedList(vl);
332 | return top_candidates;
333 | }
334 |
335 | void getNeighborsByHeuristic2( // prunes top_candidates in place to at most M entries using a pairwise-distance rule
336 | std::priority_queue, std::vector>, CompareByFirst> &top_candidates,
337 | const size_t M) {
338 | if (top_candidates.size() < M) { // fewer than M candidates: nothing is pruned
339 | return;
340 | }
341 |
342 | std::priority_queue> queue_closest; // holds negated distances; presumably default ordering, so top() is the closest candidate
343 | std::vector> return_list; // accepted candidates, still carrying negated distances
344 | while (top_candidates.size() > 0) {
345 | queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
346 | top_candidates.pop();
347 | }
348 |
349 | while (queue_closest.size()) {
350 | if (return_list.size() >= M)
351 | break;
352 | std::pair curent_pair = queue_closest.top();
353 | dist_t dist_to_query = -curent_pair.first; // undo the negation to get the real distance
354 | queue_closest.pop();
355 | bool good = true;
356 |
357 | for (std::pair second_pair : return_list) {
358 | dist_t curdist =
359 | fstdistfunc_(getDataByInternalId(second_pair.second),
360 | getDataByInternalId(curent_pair.second),
361 | dist_func_param_);;
362 | if (curdist < dist_to_query) { // rejected when an already accepted element is closer to the candidate than the query point is
363 | good = false;
364 | break;
365 | }
366 | }
367 | if (good) {
368 | return_list.push_back(curent_pair);
369 | }
370 | }
371 |
372 | for (std::pair curent_pair : return_list) {
373 | top_candidates.emplace(-curent_pair.first, curent_pair.second); // negate again so the output carries real distances
374 | }
375 | }
376 |
377 |
378 | linklistsizeint *get_linklist0(tableint internal_id) const { // level-0 link list stored inside the element's record
379 | return (linklistsizeint *) (data_level0_memory_ + offsetLevel0_ + internal_id * size_data_per_element_);
380 | };
381 |
382 | linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { // same lookup, but relative to the block passed in (the parameter shadows the member)
383 | return (linklistsizeint *) (data_level0_memory_ + offsetLevel0_ + internal_id * size_data_per_element_);
384 | };
385 |
386 | linklistsizeint *get_linklist(tableint internal_id, int level) const { // link list for a level above 0; level 1 sits at offset 0 of the element's block
387 | return (linklistsizeint *) (linkLists_[internal_id] + size_links_per_element_ * (level - 1));
388 | };
389 |
390 | linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { // picks the level-0 or the upper-level accessor
391 | return level != 0 ? get_linklist(internal_id, level) : get_linklist0(internal_id);
392 | };
393 |
394 | tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c, // links cur_c with its selected neighbours on one level, in both directions; returns the id to use as the next entry point
395 | std::priority_queue, std::vector>, CompareByFirst> &top_candidates,
396 | int level, bool isUpdate) {
397 | size_t Mcurmax = level ? maxM_ : maxM0_; // link capacity: maxM0_ on level 0, maxM_ above
398 | getNeighborsByHeuristic2(top_candidates, M_);
399 | if (top_candidates.size() > M_)
400 | throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
401 |
402 | std::vector selectedNeighbors;
403 | selectedNeighbors.reserve(M_);
404 | while (top_candidates.size() > 0) { // drains top_candidates; the caller's queue is empty afterwards
405 | selectedNeighbors.push_back(top_candidates.top().second);
406 | top_candidates.pop();
407 | }
408 |
409 | tableint next_closest_entry_point = selectedNeighbors[0]; // NOTE(review): read without an emptiness check; also, the queue pops its largest distance first, so element 0 looks like the farthest selected neighbour - confirm intent
410 |
411 | {
412 | linklistsizeint *ll_cur;
413 | if (level == 0)
414 | ll_cur = get_linklist0(cur_c);
415 | else
416 | ll_cur = get_linklist(cur_c, level);
417 |
418 | if (*ll_cur && !isUpdate) {
419 | throw std::runtime_error("The newly inserted element should have blank link list");
420 | }
421 | setListCount(ll_cur,selectedNeighbors.size());
422 | tableint *data = (tableint *) (ll_cur + 1); // neighbour ids start right after the count field
423 | for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
424 | if (data[idx] && !isUpdate)
425 | throw std::runtime_error("Possible memory corruption");
426 | if (level > element_levels_[selectedNeighbors[idx]])
427 | throw std::runtime_error("Trying to make a link on a non-existent level");
428 |
429 | data[idx] = selectedNeighbors[idx];
430 |
431 | }
432 | }
433 |
434 | for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { // second pass: add the back-link from each neighbour to cur_c
435 |
436 | std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); // held while this neighbour's list is modified
437 |
438 | linklistsizeint *ll_other;
439 | if (level == 0)
440 | ll_other = get_linklist0(selectedNeighbors[idx]);
441 | else
442 | ll_other = get_linklist(selectedNeighbors[idx], level);
443 |
444 | size_t sz_link_list_other = getListCount(ll_other);
445 |
446 | if (sz_link_list_other > Mcurmax)
447 | throw std::runtime_error("Bad value of sz_link_list_other");
448 | if (selectedNeighbors[idx] == cur_c)
449 | throw std::runtime_error("Trying to connect an element to itself");
450 | if (level > element_levels_[selectedNeighbors[idx]])
451 | throw std::runtime_error("Trying to make a link on a non-existent level");
452 |
453 | tableint *data = (tableint *) (ll_other + 1);
454 |
455 | bool is_cur_c_present = false;
456 | if (isUpdate) { // only an update can meet an existing back-link, so the scan is skipped for fresh inserts
457 | for (size_t j = 0; j < sz_link_list_other; j++) {
458 | if (data[j] == cur_c) {
459 | is_cur_c_present = true;
460 | break;
461 | }
462 | }
463 | }
464 |
465 | // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
466 | if (!is_cur_c_present) {
467 | if (sz_link_list_other < Mcurmax) { // free slot: append cur_c
468 | data[sz_link_list_other] = cur_c;
469 | setListCount(ll_other, sz_link_list_other + 1);
470 | } else {
471 | // finding the "weakest" element to replace it with the new one
472 | dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
473 | dist_func_param_);
474 | // Heuristic:
475 | std::priority_queue, std::vector>, CompareByFirst> candidates;
476 | candidates.emplace(d_max, cur_c);
477 |
478 | for (size_t j = 0; j < sz_link_list_other; j++) {
479 | candidates.emplace(
480 | fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
481 | dist_func_param_), data[j]);
482 | }
483 |
484 | getNeighborsByHeuristic2(candidates, Mcurmax); // re-select from the existing links plus cur_c
485 |
486 | int indx = 0;
487 | while (candidates.size() > 0) { // overwrite the neighbour's list with the survivors
488 | data[indx] = candidates.top().second;
489 | candidates.pop();
490 | indx++;
491 | }
492 |
493 | setListCount(ll_other, indx);
494 | // Nearest K:
495 | /*int indx = -1;
496 | for (int j = 0; j < sz_link_list_other; j++) {
497 | dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
498 | if (d > d_max) {
499 | indx = j;
500 | d_max = d;
501 | }
502 | }
503 | if (indx >= 0) {
504 | data[indx] = cur_c;
505 | } */
506 | }
507 | }
508 | }
509 |
510 | return next_closest_entry_point;
511 | }
512 |
513 | std::mutex global;
514 | size_t ef_;
515 |
516 | void setEf(size_t ef) { // stores the ef value that searchKnnInternal passes to the level-0 search
517 | this->ef_ = ef;
518 | }
519 |
520 |
521 | std::priority_queue> searchKnnInternal(void *query_data, int k) { // greedy descent through the upper levels, then a level-0 search with ef_, trimmed to at most k results
522 | std::priority_queue> top_candidates;
523 | if (cur_element_count == 0) return top_candidates; // empty index: nothing to search
524 | tableint currObj = enterpoint_node_;
525 | dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
526 |
527 | for (size_t level = maxlevel_; level > 0; level--) {
528 | bool changed = true;
529 | while (changed) { // keep moving to a closer neighbour until none improves the distance
530 | changed = false;
531 | int *data;
532 | data = (int *) get_linklist(currObj,level);
533 | int size = getListCount(data);
534 | tableint *datal = (tableint *) (data + 1);
535 | for (int i = 0; i < size; i++) {
536 | tableint cand = datal[i];
537 | if (cand >= max_elements_) // tableint is unsigned, so only the upper bound can fail; valid ids are 0 .. max_elements_ - 1
538 | throw std::runtime_error("cand error");
539 | dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
540 |
541 | if (d < curdist) {
542 | curdist = d;
543 | currObj = cand;
544 | changed = true;
545 | }
546 | }
547 | }
548 | }
549 | // NOTE(review): both branches below are textually identical in this listing; template arguments may have been lost - confirm against the original header
550 | if (has_deletions_) {
551 | std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data,
552 | ef_);
553 | top_candidates.swap(top_candidates1);
554 | }
555 | else{
556 | std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data,
557 | ef_);
558 | top_candidates.swap(top_candidates1);
559 | }
560 |
561 | while (top_candidates.size() > k) { // drop surplus entries from the top of the queue until k remain
562 | top_candidates.pop();
563 | }
564 | return top_candidates;
565 | };
566 |
567 | void resizeIndex(size_t new_max_elements){ // changes the index capacity to new_max_elements, keeping all current elements
568 | if (new_max_elements<cur_element_count) // shrinking below the number of stored elements would drop data
569 | throw std::runtime_error("Cannot resize, max element is less than the current number of elements");
570 |
571 | // The visited-list pool is created with the element capacity, so rebuild it for the new capacity.
572 | delete visited_list_pool_;
573 | visited_list_pool_ = new VisitedListPool(1, new_max_elements);
574 |
575 |
576 |
577 | element_levels_.resize(new_max_elements);
578 |
579 | std::vector<std::mutex>(new_max_elements).swap(link_list_locks_); // mutexes cannot be copied or moved, so a fresh vector replaces the old one
580 |
581 | // Reallocate base layer
582 | char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_);
583 | if (data_level0_memory_new == nullptr)
584 | throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
585 | memcpy(data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_); // only records in use are copied
586 | free(data_level0_memory_);
587 | data_level0_memory_=data_level0_memory_new;
588 |
589 | // Reallocate all other layers
590 | char ** linkLists_new = (char **) malloc(sizeof(void *) * new_max_elements);
591 | if (linkLists_new == nullptr)
592 | throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
593 | memcpy(linkLists_new, linkLists_,cur_element_count * sizeof(void *)); // copies the pointers; the per-element link blocks stay where they are
594 | free(linkLists_);
595 | linkLists_=linkLists_new;
596 |
597 | max_elements_=new_max_elements;
598 |
599 | }
600 |
601 | void saveIndex(const std::string &location) { // Writes the index to a binary file at `location`: header fields, then the level-0 block, then each element's upper-level link lists.
602 | std::ofstream output(location, std::ios::binary); // NOTE(review): nothing here checks that the file was opened successfully
603 | std::streampos position; // not used anywhere in this function
604 | // Header fields; loadIndex reads them back in exactly this order
605 | writeBinaryPOD(output, offsetLevel0_);
606 | writeBinaryPOD(output, max_elements_);
607 | writeBinaryPOD(output, cur_element_count);
608 | writeBinaryPOD(output, size_data_per_element_);
609 | writeBinaryPOD(output, label_offset_);
610 | writeBinaryPOD(output, offsetData_);
611 | writeBinaryPOD(output, maxlevel_);
612 | writeBinaryPOD(output, enterpoint_node_);
613 | writeBinaryPOD(output, maxM_);
614 |
615 | writeBinaryPOD(output, maxM0_);
616 | writeBinaryPOD(output, M_);
617 | writeBinaryPOD(output, mult_);
618 | writeBinaryPOD(output, ef_construction_);
619 | // Level-0 block: only the cur_element_count used entries are written, not the full max_elements_ capacity
620 | output.write(data_level0_memory_, cur_element_count * size_data_per_element_);
621 | // Upper layers: per element, a size prefix (0 when element_levels_[i] is 0) followed by that many bytes of link data
622 | for (size_t i = 0; i < cur_element_count; i++) {
623 | unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0;
624 | writeBinaryPOD(output, linkListSize);
625 | if (linkListSize)
626 | output.write(linkLists_[i], linkListSize);
627 | }
628 | output.close();
629 | }
630 |
631 | void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i=0) { // Loads an index in the layout written by saveIndex from `location`; data size and distance function come from `s`; max_elements_i optionally requests a capacity. Throws std::runtime_error on open failure, failed validation or failed allocation.
632 |
633 |
634 | std::ifstream input(location, std::ios::binary);
635 |
636 | if (!input.is_open())
637 | throw std::runtime_error("Cannot open file");
638 |
639 |
640 | // get file size: seek to the end, record the offset, then rewind to the start
641 | input.seekg(0,input.end);
642 | std::streampos total_filesize=input.tellg();
643 | input.seekg(0,input.beg);
644 | // Header fields, in the same order saveIndex writes them
645 | readBinaryPOD(input, offsetLevel0_);
646 | readBinaryPOD(input, max_elements_);
647 | readBinaryPOD(input, cur_element_count);
648 | // Capacity: use max_elements_i unless it is smaller than the stored element count, in which case fall back to the capacity read from the file
649 | size_t max_elements=max_elements_i;
650 | if(max_elements < cur_element_count)
651 | max_elements = max_elements_;
652 | max_elements_ = max_elements;
653 | readBinaryPOD(input, size_data_per_element_);
654 | readBinaryPOD(input, label_offset_);
655 | readBinaryPOD(input, offsetData_);
656 | readBinaryPOD(input, maxlevel_);
657 | readBinaryPOD(input, enterpoint_node_);
658 |
659 | readBinaryPOD(input, maxM_);
660 | readBinaryPOD(input, maxM0_);
661 | readBinaryPOD(input, M_);
662 | readBinaryPOD(input, mult_);
663 | readBinaryPOD(input, ef_construction_);
664 |
665 | // Distance configuration is not read from the file; it is taken from the supplied space object
666 | data_size_ = s->get_data_size();
667 | fstdistfunc_ = s->get_dist_func();
668 | dist_func_param_ = s->get_dist_func_param();
669 |
670 | auto pos=input.tellg(); // remember where the payload starts; the validation pass below moves the read position
671 |
672 |
673 | /// Optional - check if index is ok:
674 | // Walk the file without loading it: skip the level-0 block, then hop over each size-prefixed link list; the walk must end exactly at end of file
675 | input.seekg(cur_element_count * size_data_per_element_,input.cur);
676 | for (size_t i = 0; i < cur_element_count; i++) {
677 | if(input.tellg() < 0 || input.tellg()>=total_filesize){
678 | throw std::runtime_error("Index seems to be corrupted or unsupported");
679 | }
680 |
681 | unsigned int linkListSize;
682 | readBinaryPOD(input, linkListSize);
683 | if (linkListSize != 0) {
684 | input.seekg(linkListSize,input.cur);
685 | }
686 | }
687 |
688 | // throw exception if it either corrupted or old index
689 | if(input.tellg()!=total_filesize)
690 | throw std::runtime_error("Index seems to be corrupted or unsupported");
691 |
692 | input.clear(); // reset the stream state flags before seeking back
693 |
694 | /// Optional check end
695 |
696 | input.seekg(pos,input.beg); // back to the start of the payload for the real read
697 |
698 | // Allocate level-0 storage for the full capacity, but read only the cur_element_count stored entries
699 | data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
700 | if (data_level0_memory_ == nullptr)
701 | throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
702 | input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
703 |
704 |
705 |
706 | // Byte size of one link list: a size field plus maxM_ neighbor ids (upper layers) or maxM0_ neighbor ids (level 0)
707 | size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
708 |
709 |
710 | size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
711 | std::vector(max_elements).swap(link_list_locks_);
712 | std::vector(max_update_element_locks).swap(link_list_update_locks_);
713 |
714 |
715 | visited_list_pool_ = new VisitedListPool(1, max_elements);
716 |
717 |
718 | linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
719 | if (linkLists_ == nullptr)
720 | throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); // NOTE(review): data_level0_memory_ allocated above is not released on this path
721 | element_levels_ = std::vector(max_elements);
722 | revSize_ = 1.0 / mult_;
723 | ef_ = 10; // ef_ is not among the fields saveIndex writes, so it is set to a fixed value here
724 | for (size_t i = 0; i < cur_element_count; i++) {
725 | label_lookup_[getExternalLabel(i)]=i; // rebuild the label -> internal id map from the loaded level-0 data
726 | unsigned int linkListSize;
727 | readBinaryPOD(input, linkListSize);
728 | if (linkListSize == 0) {
729 | element_levels_[i] = 0;
730 |
731 | linkLists_[i] = nullptr;
732 | } else {
733 | element_levels_[i] = linkListSize / size_links_per_element_; // inverse of the size computation in saveIndex
734 | linkLists_[i] = (char *) malloc(linkListSize);
735 | if (linkLists_[i] == nullptr)
736 | throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
737 | input.read(linkLists_[i], linkListSize);
738 | }
739 | }
740 | // has_deletions_ is recomputed by scanning the loaded elements for the delete mark
741 | has_deletions_=false;
742 |
743 | for (size_t i = 0; i < cur_element_count; i++) {
744 | if(isMarkedDeleted(i))
745 | has_deletions_=true;
746 | }
747 |
748 | input.close();
749 |
750 | return;
751 | }
752 |
753 | template
754 | std::vector getDataByLabel(labeltype label) // Returns a copy of the stored vector for `label`. NOTE(review): the template parameter list and the vector element type look stripped in this listing; the body uses data_t as the element type.
755 | {
756 | tableint label_c;
757 | auto search = label_lookup_.find(label);
758 | if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { // unknown labels and labels marked deleted are both reported as not found
759 | throw std::runtime_error("Label not found");
760 | }
761 | label_c = search->second;
762 |
763 | char* data_ptrv = getDataByInternalId(label_c);
764 | size_t dim = *((size_t *) dist_func_param_); // element count is read through dist_func_param_, treated as a pointer to size_t
765 | std::vector data;
766 | data_t* data_ptr = (data_t*) data_ptrv;
767 | for (int i = 0; i < dim; i++) { // NOTE(review): signed int index compared against size_t dim
768 | data.push_back(*data_ptr);
769 | data_ptr += 1;
770 | }
771 | return data;
772 | }
773 |
774 | static const unsigned char DELETE_MARK = 0x01; // bit set, cleared and tested in the flag byte by markDeletedInternal, unmarkDeletedInternal and isMarkedDeleted
775 | // static const unsigned char REUSE_MARK = 0x10;
776 | /**
777 | * Marks an element with the given label deleted, does NOT really change the current graph.
778 | * @param label external label looked up in label_lookup_; std::runtime_error("Label not found") is thrown when it is absent
779 | */
780 | void markDelete(labeltype label)
781 | {
782 | has_deletions_=true; // set before the lookup, so it stays true even when the label is not found and the call throws
783 | auto search = label_lookup_.find(label);
784 | if (search == label_lookup_.end()) {
785 | throw std::runtime_error("Label not found");
786 | }
787 | markDeletedInternal(search->second);
788 | }
789 |
790 | /**
791 | * Sets DELETE_MARK in the byte at offset 2 of the element's level-0 link list memory.
792 | * getListCount reads the neighbor count as an unsigned short from the start of that memory, so the mark byte sits after the count.
793 | * @param internalId internal id of the element to mark
794 | */
795 | void markDeletedInternal(tableint internalId) {
796 | unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
797 | *ll_cur |= DELETE_MARK;
798 | }
799 |
800 | /**
801 | * Clears DELETE_MARK in the byte at offset 2 of the element's level-0 link list memory; other bits of that byte are kept.
802 | * @param internalId internal id of the element to unmark
803 | */
804 | void unmarkDeletedInternal(tableint internalId) {
805 | unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
806 | *ll_cur &= ~DELETE_MARK;
807 | }
808 |
809 | /**
810 | * Tests DELETE_MARK in the byte at offset 2 of the element's level-0 link list memory.
811 | * @param internalId internal id of the element to test
812 | * @return true when the DELETE_MARK bit is set for that element
813 | */
814 | bool isMarkedDeleted(tableint internalId) const {
815 | unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2;
816 | return *ll_cur & DELETE_MARK;
817 | }
818 |
819 | unsigned short int getListCount(linklistsizeint * ptr) const { // Reads the neighbor count stored as an unsigned short at the start of a link list.
820 | return *((unsigned short int *)ptr); // only the unsigned-short-sized prefix of the size field is read
821 | }
822 |
823 | void setListCount(linklistsizeint * ptr, unsigned short int size) const { // Stores the neighbor count as an unsigned short at the start of a link list.
824 | *((unsigned short int*)(ptr))=*((unsigned short int *)&size); // writes only the unsigned-short-sized prefix; the remaining bytes of the size field are left as they were
825 | }
826 |
827 | void addPoint(const void *data_point, labeltype label) { // Convenience overload: forwards to addPoint(data_point, label, level) with level = -1.
828 | addPoint(data_point, label,-1); // the three-argument overload only honours level > 0, so -1 leaves the level to getRandomLevel
829 | }
830 |
831 | void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { // Overwrites the stored vector of an existing element, rebuilds the link lists of selected one-hop neighbors on each of its layers, then re-links the element itself. updateNeighborProbability is compared against a uniform draw from (0.0, 1.0) per neighbor.
832 | // update the feature vector associated with existing point with new vector
833 | memcpy(getDataByInternalId(internalId), dataPoint, data_size_);
834 | // Snapshot the top level and entry point once; the copies are used for the rest of the update
835 | int maxLevelCopy = maxlevel_;
836 | tableint entryPointCopy = enterpoint_node_;
837 | // If point to be updated is entry point and graph just contains single element then just return.
838 | if (entryPointCopy == internalId && cur_element_count == 1)
839 | return;
840 |
841 | int elemLevel = element_levels_[internalId];
842 | std::uniform_real_distribution distribution(0.0, 1.0);
843 | for (int layer = 0; layer <= elemLevel; layer++) { // every layer the element is present on
844 | std::unordered_set sCand; // candidate pool: the element, its one-hop neighbors, and the neighbors of the selected one-hop neighbors
845 | std::unordered_set sNeigh; // one-hop neighbors selected to have their link lists rebuilt
846 | std::vector listOneHop = getConnectionsWithLock(internalId, layer);
847 | if (listOneHop.size() == 0)
848 | continue;
849 |
850 | sCand.insert(internalId);
851 |
852 | for (auto&& elOneHop : listOneHop) {
853 | sCand.insert(elOneHop);
854 |
855 | if (distribution(update_probability_generator_) > updateNeighborProbability) // neighbor stays a candidate but is not selected for re-linking
856 | continue;
857 |
858 | sNeigh.insert(elOneHop);
859 |
860 | std::vector listTwoHop = getConnectionsWithLock(elOneHop, layer);
861 | for (auto&& elTwoHop : listTwoHop) {
862 | sCand.insert(elTwoHop);
863 | }
864 | }
865 | // Rebuild the link list of each selected neighbor from the candidate pool
866 | for (auto&& neigh : sNeigh) {
867 | // if (neigh == internalId)
868 | // continue;
869 |
870 | std::priority_queue, std::vector>, CompareByFirst> candidates;
871 | int size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // pool size not counting neigh itself
872 | int elementsToKeep = std::min(int(ef_construction_), size);
873 | for (auto&& cand : sCand) {
874 | if (cand == neigh)
875 | continue;
876 | // Bounded queue: fill up to elementsToKeep, then replace top() whenever a closer candidate appears (presumably top() is the farthest kept entry - CompareByFirst is defined elsewhere)
877 | dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_);
878 | if (candidates.size() < elementsToKeep) {
879 | candidates.emplace(distance, cand);
880 | } else {
881 | if (distance < candidates.top().first) {
882 | candidates.pop();
883 | candidates.emplace(distance, cand);
884 | }
885 | }
886 | }
887 |
888 | // Retrieve neighbours using heuristic and set connections.
889 | getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_); // limit is maxM0_ on layer 0 and maxM_ on higher layers
890 | // Overwrite neigh's link list with the remaining candidates while holding its lock
891 | {
892 | std::unique_lock lock(link_list_locks_[neigh]);
893 | linklistsizeint *ll_cur;
894 | ll_cur = get_linklist_at_level(neigh, layer);
895 | int candSize = candidates.size();
896 | setListCount(ll_cur, candSize);
897 | tableint *data = (tableint *) (ll_cur + 1);
898 | for (size_t idx = 0; idx < candSize; idx++) {
899 | data[idx] = candidates.top().second;
900 | candidates.pop();
901 | }
902 | }
903 | }
904 | }
905 | // Finally reconnect the updated element itself, starting from the entry point captured above
906 | repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy);
907 | };
908 |
909 | void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) { // Re-links an updated element: greedily descends from the entry point through the layers above dataPointLevel, then for each layer from dataPointLevel down to 0 searches for candidates and reconnects the element. Throws std::runtime_error when dataPointLevel > maxLevel.
910 | tableint currObj = entryPointInternalId;
911 | if (dataPointLevel < maxLevel) { // the descent is only needed when there are layers above the element's own level
912 | dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_);
913 | for (int level = maxLevel; level > dataPointLevel; level--) {
914 | bool changed = true;
915 | while (changed) { // repeat until no neighbor of currObj is closer than currObj itself
916 | changed = false;
917 | unsigned int *data;
918 | std::unique_lock lock(link_list_locks_[currObj]); // held for one pass over currObj's neighbor list
919 | data = get_linklist_at_level(currObj,level);
920 | int size = getListCount(data);
921 | tableint *datal = (tableint *) (data + 1);
922 | #ifdef USE_SSE
923 | _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
924 | #endif
925 | for (int i = 0; i < size; i++) {
926 | #ifdef USE_SSE
927 | _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
928 | #endif
929 | tableint cand = datal[i];
930 | dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
931 | if (d < curdist) {
932 | curdist = d;
933 | currObj = cand;
934 | changed = true;
935 | }
936 | }
937 | }
938 | }
939 | }
940 |
941 | if (dataPointLevel > maxLevel)
942 | throw std::runtime_error("Level of item to be updated cannot be bigger than max level");
943 |
944 | for (int level = dataPointLevel; level >= 0; level--) {
945 | std::priority_queue, std::vector>, CompareByFirst> topCandidates = searchBaseLayer(
946 | currObj, dataPoint, level);
947 | // Copy the results into a second queue, leaving out the element being updated
948 | std::priority_queue, std::vector>, CompareByFirst> filteredTopCandidates;
949 | while (topCandidates.size() > 0) {
950 | if (topCandidates.top().second != dataPointInternalId)
951 | filteredTopCandidates.push(topCandidates.top());
952 |
953 | topCandidates.pop();
954 | }
955 |
956 | // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself.
957 | // To prevent self loops, the `topCandidates` is filtered and thus can be empty.
958 | if (filteredTopCandidates.size() > 0) {
959 | bool epDeleted = isMarkedDeleted(entryPointInternalId);
960 | if (epDeleted) { // an entry point marked deleted is added to the candidates explicitly, then the queue is trimmed back to ef_construction_ entries
961 | filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId);
962 | if (filteredTopCandidates.size() > ef_construction_)
963 | filteredTopCandidates.pop();
964 | }
965 |
966 | currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); // the returned id is the starting point for the next lower layer
967 | }
968 | }
969 | }
970 |
971 | std::vector getConnectionsWithLock(tableint internalId, int level) { // Returns a copy of internalId's neighbor ids at `level`, taken while holding that element's link list lock.
972 | std::unique_lock lock(link_list_locks_[internalId]);
973 | unsigned int *data = get_linklist_at_level(internalId, level);
974 | int size = getListCount(data);
975 | std::vector result(size);
976 | tableint *ll = (tableint *) (data + 1); // neighbor ids start right after the size field
977 | memcpy(result.data(), ll,size * sizeof(tableint));
978 | return result;
979 | };
980 |
981 | tableint addPoint(const void *data_point, labeltype label, int level) { // Inserts data_point under `label`, or updates the existing element when the label is already in label_lookup_; returns the element's internal id. level > 0 forces the element's level, otherwise getRandomLevel decides. Throws std::runtime_error when the index is full or an allocation fails.
982 |
983 | tableint cur_c = 0;
984 | {
985 | // Checking if the element with the same label already exists
986 | // if so, updating it *instead* of creating a new element.
987 | std::unique_lock templock_curr(cur_element_count_guard_);
988 | auto search = label_lookup_.find(label);
989 | if (search != label_lookup_.end()) {
990 | tableint existingInternalId = search->second;
991 |
992 | templock_curr.unlock(); // the count guard is released before the update runs
993 |
994 | std::unique_lock lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]); // ids share update locks via a bit mask (presumably max_update_element_locks is a power of two - confirm)
995 | updatePoint(data_point, existingInternalId, 1.0); // probability 1.0: updatePoint's random skip never triggers for a draw from (0.0, 1.0)
996 | return existingInternalId;
997 | }
998 |
999 | if (cur_element_count >= max_elements_) {
1000 | throw std::runtime_error("The number of elements exceeds the specified limit");
1001 | };
1002 | // Reserve the next internal id and register the label while the guard is still held
1003 | cur_c = cur_element_count;
1004 | cur_element_count++;
1005 | label_lookup_[label] = cur_c;
1006 | }
1007 |
1008 | // Take update lock to prevent race conditions on an element with insertion/update at the same time.
1009 | std::unique_lock lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]);
1010 | std::unique_lock lock_el(link_list_locks_[cur_c]);
1011 | int curlevel = getRandomLevel(mult_);
1012 | if (level > 0) // a caller-supplied positive level overrides the random one
1013 | curlevel = level;
1014 |
1015 | element_levels_[cur_c] = curlevel;
1016 |
1017 | // The global lock stays held for the rest of the insert only when this element is above the current top level; otherwise it is released immediately
1018 | std::unique_lock templock(global);
1019 | int maxlevelcopy = maxlevel_;
1020 | if (curlevel <= maxlevelcopy)
1021 | templock.unlock();
1022 | tableint currObj = enterpoint_node_;
1023 | tableint enterpoint_copy = enterpoint_node_;
1024 |
1025 | // Zero size_data_per_element_ bytes starting at offsetLevel0_ inside this element's level-0 slot
1026 | memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
1027 |
1028 | // Initialisation of the data and label
1029 | memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
1030 | memcpy(getDataByInternalId(cur_c), data_point, data_size_);
1031 |
1032 | // Upper-layer link lists are allocated, zero-filled, only for elements with a level above 0
1033 | if (curlevel) {
1034 | linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
1035 | if (linkLists_[cur_c] == nullptr)
1036 | throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
1037 | memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
1038 | }
1039 |
1040 | if ((signed)currObj != -1) { // an entry point exists; the else branch below handles the first element
1041 |
1042 | if (curlevel < maxlevelcopy) {
1043 | // Greedy descent through the layers above the new element's level to find a close starting point
1044 | dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
1045 | for (int level = maxlevelcopy; level > curlevel; level--) {
1046 |
1047 |
1048 | bool changed = true;
1049 | while (changed) {
1050 | changed = false;
1051 | unsigned int *data;
1052 | std::unique_lock lock(link_list_locks_[currObj]);
1053 | data = get_linklist(currObj,level);
1054 | int size = getListCount(data);
1055 |
1056 | tableint *datal = (tableint *) (data + 1);
1057 | for (int i = 0; i < size; i++) {
1058 | tableint cand = datal[i];
1059 | if (cand < 0 || cand > max_elements_) // NOTE(review): cand == max_elements_ passes this check - confirm whether >= was intended
1060 | throw std::runtime_error("cand error");
1061 | dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
1062 | if (d < curdist) {
1063 | curdist = d;
1064 | currObj = cand;
1065 | changed = true;
1066 | }
1067 | }
1068 | }
1069 | }
1070 | }
1071 |
1072 | bool epDeleted = isMarkedDeleted(enterpoint_copy);
1073 | for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { // connect the new element on each layer it shares with the existing graph, top down
1074 | if (level > maxlevelcopy || level < 0) // possible?
1075 | throw std::runtime_error("Level error");
1076 |
1077 | std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer(
1078 | currObj, data_point, level);
1079 | if (epDeleted) { // an entry point marked deleted is added to the candidates explicitly, then the queue is trimmed back to ef_construction_ entries
1080 | top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
1081 | if (top_candidates.size() > ef_construction_)
1082 | top_candidates.pop();
1083 | }
1084 | currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); // the returned id is the starting point for the next lower layer
1085 | }
1086 |
1087 |
1088 | } else {
1089 | // Do nothing for the first element
1090 | enterpoint_node_ = 0;
1091 | maxlevel_ = curlevel;
1092 |
1093 | }
1094 | // An element above the previous top level becomes the new entry point; templock was not unlocked above in this case
1095 | //Releasing lock for the maximum level
1096 | if (curlevel > maxlevelcopy) {
1097 | enterpoint_node_ = cur_c;
1098 | maxlevel_ = curlevel;
1099 | }
1100 | return cur_c;
1101 | };
1102 |
1103 | std::priority_queue>
1104 | searchKnn(const void *query_data, size_t k) const { // Returns at most k results for query_data as a queue of (distance, external label) pairs; the queue is empty when the index holds no elements.
1105 | std::priority_queue> result;
1106 | if (cur_element_count == 0) return result;
1107 |
1108 | tableint currObj = enterpoint_node_;
1109 | dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
1110 | // Greedy descent through the upper layers: on each level keep moving to a closer neighbor until none improves the distance
1111 | for (int level = maxlevel_; level > 0; level--) {
1112 | bool changed = true;
1113 | while (changed) {
1114 | changed = false;
1115 | unsigned int *data;
1116 |
1117 | data = (unsigned int *) get_linklist(currObj, level);
1118 | int size = getListCount(data);
1119 | metric_hops++; // counters updated once per visited node and once per neighbor list length
1120 | metric_distance_computations+=size;
1121 |
1122 | tableint *datal = (tableint *) (data + 1);
1123 | for (int i = 0; i < size; i++) {
1124 | tableint cand = datal[i];
1125 | if (cand < 0 || cand > max_elements_) // NOTE(review): cand == max_elements_ passes this check - confirm whether >= was intended
1126 | throw std::runtime_error("cand error");
1127 | dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);
1128 |
1129 | if (d < curdist) {
1130 | curdist = d;
1131 | currObj = cand;
1132 | changed = true;
1133 | }
1134 | }
1135 | }
1136 | }
1137 | // Level-0 search with a candidate list size of max(ef_, k)
1138 | std::priority_queue, std::vector>, CompareByFirst> top_candidates;
1139 | if (has_deletions_) { // NOTE(review): both branches read identically in this listing; presumably they differed by a template argument lost in extraction - confirm
1140 | top_candidates=searchBaseLayerST(
1141 | currObj, query_data, std::max(ef_, k));
1142 | }
1143 | else{
1144 | top_candidates=searchBaseLayerST(
1145 | currObj, query_data, std::max(ef_, k));
1146 | }
1147 | // Trim to k entries, then translate internal ids to external labels
1148 | while (top_candidates.size() > k) {
1149 | top_candidates.pop();
1150 | }
1151 | while (top_candidates.size() > 0) {
1152 | std::pair rez = top_candidates.top();
1153 | result.push(std::pair(rez.first, getExternalLabel(rez.second)));
1154 | top_candidates.pop();
1155 | }
1156 | return result;
1157 | };
1158 |
1159 | template
1160 | std::vector>
1161 | searchKnn(const void* query_data, size_t k, Comp comp) { // Overload returning the results of searchKnn(query_data, k) as a vector sorted with the caller-supplied comparator `comp`.
1162 | std::vector> result;
1163 | if (cur_element_count == 0) return result;
1164 |
1165 | auto ret = searchKnn(query_data, k);
1166 | // Drain the queue into the vector; the final order is decided by the sort below
1167 | while (!ret.empty()) {
1168 | result.push_back(ret.top());
1169 | ret.pop();
1170 | }
1171 |
1172 | std::sort(result.begin(), result.end(), comp);
1173 |
1174 | return result;
1175 | }
1176 |
1177 | void checkIntegrity(){ // Debug check: walks every link list of every element, asserts invariants on the neighbor ids, and prints inbound-connection statistics to std::cout.
1178 | int connections_checked=0;
1179 | std::vector inbound_connections_num(cur_element_count,0);
1180 | for(int i = 0;i < cur_element_count; i++){
1181 | for(int l = 0;l <= element_levels_[i]; l++){
1182 | linklistsizeint *ll_cur = get_linklist_at_level(i,l);
1183 | int size = getListCount(ll_cur);
1184 | tableint *data = (tableint *) (ll_cur + 1);
1185 | std::unordered_set s;
1186 | for (int j=0; j 0); // NOTE(review): malformed as shown - the embedded numbering jumps from 1186 to 1188, so the loop header and a following line appear to have been merged in extraction; confirm against the original header
1188 | assert(data[j] < cur_element_count); // neighbor id must be below the current element count
1189 | assert (data[j] != i); // an element must not list itself as a neighbor
1190 | inbound_connections_num[data[j]]++;
1191 | s.insert(data[j]);
1192 | connections_checked++;
1193 |
1194 | }
1195 | assert(s.size() == size); // the set has as many entries as the list, i.e. no duplicate neighbors
1196 | }
1197 | }
1198 | if(cur_element_count > 1){
1199 | int min1=inbound_connections_num[0], max1=inbound_connections_num[0];
1200 | for(int i=0; i < cur_element_count; i++){
1201 | assert(inbound_connections_num[i] > 0); // every element must be listed as a neighbor by at least one element
1202 | min1=std::min(inbound_connections_num[i],min1);
1203 | max1=std::max(inbound_connections_num[i],max1);
1204 | }
1205 | std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n";
1206 | }
1207 | std::cout << "integrity ok, checked " << connections_checked << " connections\n";
1208 |
1209 | }
1210 |
1211 | };
1212 |
1213 | }
1214 |
--------------------------------------------------------------------------------