├── .gitmodules ├── Makefile ├── hnsw_wrapper.h ├── .gitignore ├── hnsw_wrapper.cc ├── hnswlib ├── visited_list_pool.h ├── hnswlib.h ├── bruteforce.h ├── space_l2.h ├── space_ip.h └── hnswalg.h ├── README.md └── hnsw.go /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "hnswlib"] 2 | path = hnswlib 3 | url = https://github.com/nmslib/hnswlib.git 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | CXX=g++ 2 | INCLUDES=-I. 3 | CXXFLAGS=-fPIC -pthread -Wall -std=c++11 -O2 -march=native $(INCLUDES) 4 | LDFLAGS=-shared 5 | OBJS=hnsw_wrapper.o 6 | TARGET=libhnsw.so 7 | 8 | .PHONY: all clean 9 | 10 | all: $(TARGET) 11 | 12 | $(OBJS): hnsw_wrapper.h hnsw_wrapper.cc hnswlib/*.h 13 | $(CXX) $(CXXFLAGS) -c hnsw_wrapper.cc 14 | 15 | $(TARGET): $(OBJS) 16 | $(CXX) $(LDFLAGS) -o $(TARGET) $(OBJS) 17 | 18 | clean: 19 | rm -rf $(OBJS) $(TARGET) 20 | -------------------------------------------------------------------------------- /hnsw_wrapper.h: -------------------------------------------------------------------------------- 1 | // hnsw_wrapper.h: C ABI over hnswlib's HierarchicalNSW index, used from Go through cgo (hnsw.go). 2 | // HNSW is an opaque handle: create it with initHNSW or loadHNSW, release it with freeHNSW. 3 | #ifndef HNSW_WRAPPER_H 4 | #define HNSW_WRAPPER_H 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | typedef void* HNSW; 10 | // stype 'i' selects the inner-product space; any other value selects L2. 11 | HNSW initHNSW(int dim, unsigned long long int max_elements, int M, int ef_construction, int rand_seed, char stype); 12 | HNSW loadHNSW(char *location, int dim, char stype); 13 | // Writes the index to location and returns the same handle. 14 | HNSW saveHNSW(HNSW index, char *location); 15 | void freeHNSW(HNSW index); 16 | void addPoint(HNSW index, float *vec, unsigned long long int label); 17 | // Fills label/dist with up to N neighbours, nearest first; returns the number written (0 if the search throws). 18 | int searchKnn(HNSW index, float *vec, int N, unsigned long long int *label, float *dist); 19 | void setEf(HNSW index, int ef); 20 | #ifdef __cplusplus 21 | } 22 | #endif 23 | 24 | #endif /* HNSW_WRAPPER_H */ 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for
programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | # Prerequisites 18 | *.d 19 | 20 | # Compiled Object files 21 | *.slo 22 | *.lo 23 | *.o 24 | *.obj 25 | 26 | # Precompiled Headers 27 | *.gch 28 | *.pch 29 | 30 | # Compiled Dynamic libraries 31 | *.so 32 | *.dylib 33 | *.dll 34 | 35 | # Fortran module files 36 | *.mod 37 | *.smod 38 | 39 | # Compiled Static libraries 40 | *.lai 41 | *.la 42 | *.a 43 | *.lib 44 | 45 | # Executables 46 | *.exe 47 | *.out 48 | *.app 49 | -------------------------------------------------------------------------------- /hnsw_wrapper.cc: -------------------------------------------------------------------------------- 1 | // hnsw_wrapper.cc: implements the C ABI declared in hnsw_wrapper.h on top of hnswlib::HierarchicalNSW<float>. 2 | // Each HNSW handle points at an HnswHandle that owns both the index and the space it was built with. 3 | // The index keeps s->get_dist_func_param(), a pointer into the space, so the space must outlive the 4 | // index: freeHNSW deletes the index first and then the space. 5 | #include <exception> 6 | #include <queue> 7 | #include <string> 8 | #include <utility> 9 | #include "hnswlib/hnswlib.h" 10 | #include "hnsw_wrapper.h" 11 | 12 | namespace { 13 | 14 | struct HnswHandle { 15 | hnswlib::SpaceInterface<float> *space; 16 | hnswlib::HierarchicalNSW<float> *alg; 17 | }; 18 | 19 | // stype 'i' selects the inner-product space; any other value selects L2. 20 | hnswlib::SpaceInterface<float> *makeSpace(int dim, char stype) { 21 | if (stype == 'i') { 22 | return new hnswlib::InnerProductSpace(dim); 23 | } 24 | return new hnswlib::L2Space(dim); 25 | } 26 | 27 | hnswlib::HierarchicalNSW<float> *algOf(HNSW index) { 28 | return static_cast<HnswHandle *>(index)->alg; 29 | } 30 | 31 | } // namespace 32 | 33 | HNSW initHNSW(int dim, unsigned long long int max_elements, int M, int ef_construction, int rand_seed, char stype) { 34 | hnswlib::SpaceInterface<float> *space = makeSpace(dim, stype); 35 | hnswlib::HierarchicalNSW<float> *alg = new hnswlib::HierarchicalNSW<float>(space, max_elements, M, ef_construction, rand_seed); 36 | return new HnswHandle{space, alg}; 37 | } 38 | 39 | HNSW loadHNSW(char *location, int dim, char stype) { 40 | hnswlib::SpaceInterface<float> *space = makeSpace(dim, stype); 41 | hnswlib::HierarchicalNSW<float> *alg = new hnswlib::HierarchicalNSW<float>(space, std::string(location), false, 0); 42 | return new HnswHandle{space, alg}; 43 | } 44 | 45 | // Writes the index to location and returns the same handle. 46 | HNSW saveHNSW(HNSW index, char *location) { 47 |
algOf(index)->saveIndex(location); 48 | return index; 49 | } 50 | 51 | void freeHNSW(HNSW index) { 52 | HnswHandle *handle = static_cast<HnswHandle *>(index); 53 | if (handle == nullptr) return; // freeing a null handle stays a no-op 54 | delete handle->alg; // the index first: it still points into the space 55 | delete handle->space; 56 | delete handle; 57 | } 58 | 59 | // NOTE(review): anything hnswlib throws here (presumably e.g. when max_elements is exceeded -- confirm in hnswalg.h) still escapes this extern "C" function. 60 | void addPoint(HNSW index, float *vec, unsigned long long int label) { 61 | algOf(index)->addPoint(vec, label); 62 | } 63 | 64 | // Fills label/dist with up to N neighbours, nearest first; returns the number written (0 on error or when N <= 0). 65 | int searchKnn(HNSW index, float *vec, int N, unsigned long long int *label, float *dist) { 66 | if (N <= 0) return 0; // a negative N would otherwise become a huge size_t k 67 | std::priority_queue<std::pair<float, hnswlib::labeltype>> gt; 68 | try { 69 | gt = algOf(index)->searchKnn(vec, N); 70 | } catch (const std::exception &) { 71 | return 0; 72 | } 73 | // gt is a max-heap on distance, so filling back to front leaves the nearest neighbour at index 0. 74 | int n = static_cast<int>(gt.size()); 75 | for (int i = n - 1; i >= 0; i--) { 76 | dist[i] = gt.top().first; 77 | label[i] = gt.top().second; 78 | gt.pop(); 79 | } 80 | return n; 81 | } 82 | 83 | void setEf(HNSW index, int ef) { 84 | algOf(index)->ef_ = ef; 85 | } 86 | -------------------------------------------------------------------------------- /hnswlib/visited_list_pool.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace hnswlib { 7 | typedef unsigned short int vl_type; 8 | 9 | class VisitedList { 10 | public: 11 | vl_type curV; 12 | vl_type *mass; 13 | unsigned int numelements; 14 | 15 | VisitedList(int numelements1) { 16 | curV = -1; 17 | numelements = numelements1; 18 | mass = new vl_type[numelements]; 19 | } 20 | 21 | void reset() { 22 | curV++; 23 | if (curV == 0) { 24 | memset(mass, 0, sizeof(vl_type) * numelements); 25 | curV++; 26 | } 27 | }; 28 | 29 | ~VisitedList() { delete[] mass; } 30 | }; 31 | /////////////////////////////////////////////////////////// 32 | // 33 | // Class for multi-threaded pool-management of VisitedLists 34 | // 35 | ///////////////////////////////////////////////////////// 36 | 37 | class VisitedListPool { 38 | std::deque pool;
39 | std::mutex poolguard; 40 | int numelements; 41 | 42 | public: 43 | VisitedListPool(int initmaxpools, int numelements1) { 44 | numelements = numelements1; 45 | for (int i = 0; i < initmaxpools; i++) 46 | pool.push_front(new VisitedList(numelements)); 47 | } 48 | 49 | VisitedList *getFreeVisitedList() { 50 | VisitedList *rez; 51 | { 52 | std::unique_lock lock(poolguard); 53 | if (pool.size() > 0) { 54 | rez = pool.front(); 55 | pool.pop_front(); 56 | } else { 57 | rez = new VisitedList(numelements); 58 | } 59 | } 60 | rez->reset(); 61 | return rez; 62 | }; 63 | 64 | void releaseVisitedList(VisitedList *vl) { 65 | std::unique_lock lock(poolguard); 66 | pool.push_front(vl); 67 | }; 68 | 69 | ~VisitedListPool() { 70 | while (pool.size()) { 71 | VisitedList *rez = pool.front(); 72 | pool.pop_front(); 73 | delete rez; 74 | } 75 | }; 76 | }; 77 | } 78 | 79 | -------------------------------------------------------------------------------- /hnswlib/hnswlib.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef NO_MANUAL_VECTORIZATION 3 | #ifdef __SSE__ 4 | #define USE_SSE 5 | #ifdef __AVX__ 6 | #define USE_AVX 7 | #endif 8 | #endif 9 | #endif 10 | 11 | #if defined(USE_AVX) || defined(USE_SSE) 12 | #ifdef _MSC_VER 13 | #include 14 | #include 15 | #else 16 | #include 17 | #endif 18 | 19 | #if defined(__GNUC__) 20 | #define PORTABLE_ALIGN32 __attribute__((aligned(32))) 21 | #else 22 | #define PORTABLE_ALIGN32 __declspec(align(32)) 23 | #endif 24 | #endif 25 | 26 | #include 27 | #include 28 | 29 | #include 30 | 31 | namespace hnswlib { 32 | typedef size_t labeltype; 33 | 34 | template 35 | class pairGreater { 36 | public: 37 | bool operator()(const T& p1, const T& p2) { 38 | return p1.first > p2.first; 39 | } 40 | }; 41 | 42 | template 43 | static void writeBinaryPOD(std::ostream &out, const T &podRef) { 44 | out.write((char *) &podRef, sizeof(T)); 45 | } 46 | 47 | template 48 | static void readBinaryPOD(std::istream
&in, T &podRef) { 49 | in.read((char *) &podRef, sizeof(T)); 50 | } 51 | 52 | template 53 | using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); 54 | 55 | 56 | template 57 | class SpaceInterface { 58 | public: 59 | //virtual void search(void *); 60 | virtual size_t get_data_size() = 0; 61 | 62 | virtual DISTFUNC get_dist_func() = 0; 63 | 64 | virtual void *get_dist_func_param() = 0; 65 | 66 | virtual ~SpaceInterface() {} 67 | }; 68 | 69 | template 70 | class AlgorithmInterface { 71 | public: 72 | virtual void addPoint(const void *datapoint, labeltype label)=0; 73 | virtual std::priority_queue> searchKnn(const void *, size_t) const = 0; 74 | template 75 | std::vector> searchKnn(const void*, size_t, Comp) { 76 | } 77 | virtual void saveIndex(const std::string &location)=0; 78 | virtual ~AlgorithmInterface(){ 79 | } 80 | }; 81 | 82 | 83 | } 84 | 85 | #include "space_l2.h" 86 | #include "space_ip.h" 87 | #include "bruteforce.h" 88 | #include "hnswalg.h" 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HNSWGO 2 | This is a Go interface to [hnswlib](https://github.com/nmslib/hnswlib). For more information, see [hnswlib](https://github.com/nmslib/hnswlib) and [Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs](https://arxiv.org/abs/1603.09320). 3 | 4 | # Compile (Optional) 5 | ```bash 6 | git clone --recursive https://github.com/evan176/hnswgo.git 7 | cd hnswgo && make 8 | sudo cp libhnsw.so /usr/local/lib 9 | sudo ldconfig 10 | ``` 11 | # Usage 12 | 1. Download shared library 13 | ```bash 14 | sudo wget https://github.com/evan176/hnswgo/releases/download/v1/libhnsw.so -P /usr/local/lib/ 15 | sudo ldconfig 16 | ``` 17 | 2. Export CGO variable 18 | ``` 19 | export CGO_CXXFLAGS=-std=c++11 20 | ``` 21 | 3.
Go get 22 | ``` 23 | go get github.com/evan176/hnswgo 24 | ``` 25 | 26 | | argument | type | description | 27 | | -------------- | ------ | ----------- | 28 | | dim | int | vector dimension | 29 | | M | int | see [ALGO_PARAMS.md](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md) | 30 | | efConstruction | int | see [ALGO_PARAMS.md](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md) | 31 | | randomSeed | int | random seed for the HNSW level generator | 32 | | maxElements | uint64 | maximum number of vectors the index can hold | 33 | | spaceType | string | `l2`, `ip`, or `cosine` (see below) | 34 | 35 | | spaceType | distance | 36 | | --------- |:-----------------:| 37 | | ip | inner product distance (1 - dot product) | 38 | | cosine | cosine distance (1 - cosine similarity) | 39 | | l2 | squared Euclidean (L2) distance | 40 | 41 | ```go 42 | package main 43 | 44 | import ( 45 | "fmt" 46 | "math/rand" 47 | 48 | "github.com/evan176/hnswgo" 49 | ) 50 | 51 | func randVector(dim int) []float32 { 52 | vec := make([]float32, dim) 53 | for j := 0; j < dim; j++ { 54 | vec[j] = rand.Float32() 55 | } 56 | return vec 57 | } 58 | 59 | func main() { 60 | var dim, M, efConstruction, randomSeed int = 128, 32, 300, 100 61 | // Maximum number of elements the index can hold 62 | var maxElements uint64 = 1000 63 | // Define search space: l2, ip (inner product) or cosine 64 | var spaceType, indexLocation string = "l2", "hnsw_index.bin" 65 | // Init new index with capacity for 1000 vectors in l2 space 66 | h := hnswgo.New(dim, M, efConstruction, randomSeed, maxElements, spaceType) 67 | 68 | // Insert 1000 vectors to index. Label type is uint64. 69 | for i := uint64(0); i < maxElements; i++ { 70 | h.AddPoint(randVector(dim), i) 71 | } 72 | h.Save(indexLocation) 73 | h.Free() // release the in-memory index before reloading it from disk 74 | h = hnswgo.Load(indexLocation, dim, spaceType) 75 | defer h.Free() 76 | 77 | // Search vector with maximum 10 nearest neighbors 78 | h.SetEf(15) 79 | searchVector := randVector(dim) 80 | labels, dists := h.SearchKNN(searchVector, 10) 81 | for i, l := range labels { 82 | fmt.Printf("Nearest label: %d, dist: %f\n", l, dists[i]) 83 | } 84 | } 85 | ``` 86 | 87 | # References 88 | Malkov, Yu A., and D. A. Yashunin.
"Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs." TPAMI, preprint: [https://arxiv.org/abs/1603.09320](https://arxiv.org/abs/1603.09320) 89 | -------------------------------------------------------------------------------- /hnsw.go: -------------------------------------------------------------------------------- 1 | package hnswgo 2 | 3 | // #cgo CFLAGS: -I./ 4 | // #cgo LDFLAGS: -lhnsw 5 | // #include <stdlib.h> 6 | // #include "hnsw_wrapper.h" 7 | import "C" 8 | import ( 9 | "math" 10 | "unsafe" 11 | ) 12 | 13 | // HNSW wraps a native hnswlib index; call Free to release it. 14 | type HNSW struct { 15 | index C.HNSW 16 | spaceType string 17 | dim int 18 | normalize bool 19 | } 20 | 21 | // spaceCode maps spaceType to the wrapper's stype byte and reports whether vectors 22 | // must be normalized first: "cosine" is inner product over unit-length vectors. 23 | func spaceCode(spaceType string) (C.char, bool) { 24 | switch spaceType { 25 | case "ip": 26 | return C.char('i'), false 27 | case "cosine": 28 | return C.char('i'), true 29 | default: 30 | return C.char('l'), false 31 | } 32 | } 33 | 34 | // New creates an empty index. spaceType is "ip", "cosine", or anything else for L2. 35 | func New(dim, M, efConstruction, randSeed int, maxElements uint64, spaceType string) *HNSW { 36 | stype, normalize := spaceCode(spaceType) 37 | return &HNSW{ 38 | index: C.initHNSW(C.int(dim), C.ulonglong(maxElements), C.int(M), C.int(efConstruction), C.int(randSeed), stype), 39 | spaceType: spaceType, 40 | dim: dim, 41 | normalize: normalize, 42 | } 43 | } 44 | 45 | // Load reads an index written by Save; dim and spaceType should match the ones it was built with. 46 | func Load(location string, dim int, spaceType string) *HNSW { 47 | stype, normalize := spaceCode(spaceType) 48 | pLocation := C.CString(location) 49 | defer C.free(unsafe.Pointer(pLocation)) 50 | return &HNSW{ 51 | index: C.loadHNSW(pLocation, C.int(dim), stype), 52 | spaceType: spaceType, 53 | dim: dim, 54 | normalize: normalize, 55 | } 56 | } 57 | 58 | // Save writes the index to location. 59 | func (h *HNSW) Save(location string) { 60 | pLocation := C.CString(location) 61 | defer C.free(unsafe.Pointer(pLocation)) 62 |
C.saveHNSW(h.index, pLocation) 63 | } 64 | 65 | // Free releases the native index; calling it again is a no-op. 66 | func (h *HNSW) Free() { 67 | if h.index == nil { 68 | return 69 | } 70 | C.freeHNSW(h.index) 71 | h.index = nil 72 | } 73 | 74 | // normalizeVector returns a unit-length copy of vector and leaves the caller's slice untouched. 75 | func normalizeVector(vector []float32) []float32 { 76 | var norm float32 77 | for _, v := range vector { 78 | norm += v * v 79 | } 80 | norm = 1.0 / (float32(math.Sqrt(float64(norm))) + 1e-15) 81 | out := make([]float32, len(vector)) 82 | for i, v := range vector { 83 | out[i] = v * norm 84 | } 85 | return out 86 | } 87 | 88 | // checkDim panics when vector is shorter than the index dimension: the native 89 | // side reads dim floats starting at &vector[0], which would run past the slice. 90 | func (h *HNSW) checkDim(vector []float32) { 91 | if len(vector) < h.dim { 92 | panic("hnswgo: vector is shorter than the index dimension") 93 | } 94 | } 95 | 96 | // AddPoint inserts vector under label. 97 | func (h *HNSW) AddPoint(vector []float32, label uint64) { 98 | h.checkDim(vector) 99 | if h.normalize { 100 | vector = normalizeVector(vector) 101 | } 102 | C.addPoint(h.index, (*C.float)(unsafe.Pointer(&vector[0])), C.ulonglong(label)) 103 | } 104 | 105 | // SearchKNN returns up to N nearest labels and their distances, nearest first. 106 | func (h *HNSW) SearchKNN(vector []float32, N int) ([]uint64, []float32) { 107 | if N <= 0 { 108 | return []uint64{}, []float32{} 109 | } 110 | h.checkDim(vector) 111 | if h.normalize { 112 | vector = normalizeVector(vector) 113 | } 114 | Clabel := make([]C.ulonglong, N) 115 | Cdist := make([]C.float, N) 116 | numResult := int(C.searchKnn(h.index, (*C.float)(unsafe.Pointer(&vector[0])), C.int(N), &Clabel[0], &Cdist[0])) 117 | labels := make([]uint64, numResult) 118 | dists := make([]float32, numResult) 119 | for i := 0; i < numResult; i++ { 120 | labels[i] = uint64(Clabel[i]) 121 | dists[i] = float32(Cdist[i]) 122 | } 123 | return labels, dists 124 | } 125 | 126 | // SetEf sets the index's ef_ search parameter. 127 | func (h *HNSW) SetEf(ef int) { 128 | C.setEf(h.index, C.int(ef)) 129 | } 130 | -------------------------------------------------------------------------------- /hnswlib/bruteforce.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | namespace hnswlib { 8 | template 9 | class BruteforceSearch : public AlgorithmInterface { 10 | public: 11 | BruteforceSearch(SpaceInterface *s) { 12 | data_ = nullptr; maxelements_ = 0; cur_element_count = 0; // nothing allocated yet; the destructor frees data_ 13 | } 14 | BruteforceSearch(SpaceInterface *s, const std::string &location) { 15 | loadIndex(location, s); 16 | } 17 | 18 | BruteforceSearch(SpaceInterface *s, size_t maxElements) { 19 | maxelements_ =
maxElements; 20 | data_size_ = s->get_data_size(); 21 | fstdistfunc_ = s->get_dist_func(); 22 | dist_func_param_ = s->get_dist_func_param(); 23 | size_per_element_ = data_size_ + sizeof(labeltype); 24 | data_ = (char *) malloc(maxElements * size_per_element_); 25 | if (data_ == nullptr) 26 | throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); 27 | cur_element_count = 0; 28 | } 29 | 30 | ~BruteforceSearch() { 31 | free(data_); 32 | } 33 | 34 | char *data_; 35 | size_t maxelements_; 36 | size_t cur_element_count; 37 | size_t size_per_element_; 38 | 39 | size_t data_size_; 40 | DISTFUNC fstdistfunc_; 41 | void *dist_func_param_; 42 | std::mutex index_lock; 43 | 44 | std::unordered_map dict_external_to_internal; 45 | 46 | void addPoint(const void *datapoint, labeltype label) { 47 | 48 | int idx; 49 | { 50 | std::unique_lock lock(index_lock); 51 | 52 | 53 | 54 | auto search=dict_external_to_internal.find(label); 55 | if (search != dict_external_to_internal.end()) { 56 | idx=search->second; 57 | } 58 | else{ 59 | if (cur_element_count >= maxelements_) { 60 | throw std::runtime_error("The number of elements exceeds the specified limit\n"); 61 | } 62 | idx=cur_element_count; 63 | dict_external_to_internal[label] = idx; 64 | cur_element_count++; 65 | } 66 | } 67 | memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); 68 | memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); 69 | 70 | 71 | 72 | 73 | }; 74 | 75 | void removePoint(labeltype cur_external) { 76 | auto found = dict_external_to_internal.find(cur_external); 77 | if (found == dict_external_to_internal.end()) return; // unknown label: nothing to remove 78 | size_t cur_c = found->second; 79 | size_t last = cur_element_count - 1; 80 | labeltype label=*((labeltype*)(data_ + size_per_element_ * last + data_size_)); 81 | dict_external_to_internal[label]=cur_c; // the last stored element takes over the freed slot 82 | if (cur_c != last) memcpy(data_ + size_per_element_ * cur_c, 83 | data_ + size_per_element_ * last, 84 | data_size_+sizeof(labeltype)); 85 | dict_external_to_internal.erase(cur_external); // after the remap, so removing the last element cannot re-insert its label 86 | cur_element_count--; 87 | } 88 | 89 | 90 |
std::priority_queue> 91 | searchKnn(const void *query_data, size_t k) const { 92 | std::priority_queue> topResults; 93 | if (cur_element_count == 0 || k == 0) return topResults; 94 | for (size_t i = 0; i < k && i < cur_element_count; i++) { // never read past the stored elements 95 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); 96 | topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + 97 | data_size_)))); 98 | } 99 | dist_t lastdist = topResults.top().first; 100 | for (size_t i = k; i < cur_element_count; i++) { 101 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); 102 | if (dist <= lastdist) { 103 | topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + 104 | data_size_)))); 105 | if (topResults.size() > k) 106 | topResults.pop(); 107 | lastdist = topResults.top().first; 108 | } 109 | 110 | } 111 | return topResults; 112 | }; 113 | 114 | template 115 | std::vector> 116 | searchKnn(const void* query_data, size_t k, Comp comp) { 117 | std::vector> result; 118 | if (cur_element_count == 0) return result; 119 | 120 | auto ret = searchKnn(query_data, k); 121 | 122 | while (!ret.empty()) { 123 | result.push_back(ret.top()); 124 | ret.pop(); 125 | } 126 | 127 | std::sort(result.begin(), result.end(), comp); 128 | 129 | return result; 130 | } 131 | 132 | void saveIndex(const std::string &location) { 133 | std::ofstream output(location, std::ios::binary); 134 | std::streampos position; 135 | 136 | writeBinaryPOD(output, maxelements_); 137 | writeBinaryPOD(output, size_per_element_); 138 | writeBinaryPOD(output, cur_element_count); 139 | 140 | output.write(data_, maxelements_ * size_per_element_); 141 | 142 | output.close(); 143 | } 144 | 145 | void loadIndex(const std::string &location, SpaceInterface *s) { 146 | 147 | 148 | std::ifstream input(location, std::ios::binary); 149 | std::streampos position; 150 | 151 | readBinaryPOD(input, maxelements_); 152 | readBinaryPOD(input, size_per_element_); 153 |
readBinaryPOD(input, cur_element_count); 154 | 155 | data_size_ = s->get_data_size(); 156 | fstdistfunc_ = s->get_dist_func(); 157 | dist_func_param_ = s->get_dist_func_param(); 158 | size_per_element_ = data_size_ + sizeof(labeltype); 159 | data_ = (char *) malloc(maxelements_ * size_per_element_); 160 | if (data_ == nullptr) 161 | throw std::runtime_error("Not enough memory: loadIndex failed to allocate data"); 162 | 163 | input.read(data_, maxelements_ * size_per_element_); 164 | 165 | input.close(); 166 | 167 | } 168 | 169 | }; 170 | } 171 | -------------------------------------------------------------------------------- /hnswlib/space_l2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "hnswlib.h" 3 | 4 | namespace hnswlib { 5 | 6 | static float 7 | L2Sqr(const void *pVect1, const void *pVect2, const void *qty_ptr) { 8 | //return *((float *)pVect2); 9 | size_t qty = *((size_t *) qty_ptr); 10 | float res = 0; 11 | for (unsigned i = 0; i < qty; i++) { 12 | float t = ((float *) pVect1)[i] - ((float *) pVect2)[i]; 13 | res += t * t; 14 | } 15 | return (res); 16 | 17 | } 18 | 19 | #if defined(USE_AVX) 20 | 21 | // Favor using AVX if available.
22 | static float 23 | L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 24 | float *pVect1 = (float *) pVect1v; 25 | float *pVect2 = (float *) pVect2v; 26 | size_t qty = *((size_t *) qty_ptr); 27 | float PORTABLE_ALIGN32 TmpRes[8]; 28 | size_t qty16 = qty >> 4; 29 | 30 | const float *pEnd1 = pVect1 + (qty16 << 4); 31 | 32 | __m256 diff, v1, v2; 33 | __m256 sum = _mm256_set1_ps(0); 34 | 35 | while (pVect1 < pEnd1) { 36 | v1 = _mm256_loadu_ps(pVect1); 37 | pVect1 += 8; 38 | v2 = _mm256_loadu_ps(pVect2); 39 | pVect2 += 8; 40 | diff = _mm256_sub_ps(v1, v2); 41 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); 42 | 43 | v1 = _mm256_loadu_ps(pVect1); 44 | pVect1 += 8; 45 | v2 = _mm256_loadu_ps(pVect2); 46 | pVect2 += 8; 47 | diff = _mm256_sub_ps(v1, v2); 48 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); 49 | } 50 | 51 | _mm256_store_ps(TmpRes, sum); 52 | float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; 53 | 54 | return (res); 55 | } 56 | 57 | #elif defined(USE_SSE) 58 | 59 | static float 60 | L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 61 | float *pVect1 = (float *) pVect1v; 62 | float *pVect2 = (float *) pVect2v; 63 | size_t qty = *((size_t *) qty_ptr); 64 | float PORTABLE_ALIGN32 TmpRes[8]; 65 | // size_t qty4 = qty >> 2; 66 | size_t qty16 = qty >> 4; 67 | 68 | const float *pEnd1 = pVect1 + (qty16 << 4); 69 | // const float* pEnd2 = pVect1 + (qty4 << 2); 70 | // const float* pEnd3 = pVect1 + qty; 71 | 72 | __m128 diff, v1, v2; 73 | __m128 sum = _mm_set1_ps(0); 74 | 75 | while (pVect1 < pEnd1) { 76 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 77 | v1 = _mm_loadu_ps(pVect1); 78 | pVect1 += 4; 79 | v2 = _mm_loadu_ps(pVect2); 80 | pVect2 += 4; 81 | diff = _mm_sub_ps(v1, v2); 82 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 83 | 84 | v1 = _mm_loadu_ps(pVect1); 85 | pVect1 += 4; 86 | v2 = _mm_loadu_ps(pVect2); 87 | pVect2 += 4;
88 | diff = _mm_sub_ps(v1, v2); 89 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 90 | 91 | v1 = _mm_loadu_ps(pVect1); 92 | pVect1 += 4; 93 | v2 = _mm_loadu_ps(pVect2); 94 | pVect2 += 4; 95 | diff = _mm_sub_ps(v1, v2); 96 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 97 | 98 | v1 = _mm_loadu_ps(pVect1); 99 | pVect1 += 4; 100 | v2 = _mm_loadu_ps(pVect2); 101 | pVect2 += 4; 102 | diff = _mm_sub_ps(v1, v2); 103 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 104 | } 105 | _mm_store_ps(TmpRes, sum); 106 | float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 107 | 108 | return (res); 109 | } 110 | #endif 111 | 112 | 113 | #ifdef USE_SSE 114 | static float 115 | L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 116 | float PORTABLE_ALIGN32 TmpRes[8]; 117 | float *pVect1 = (float *) pVect1v; 118 | float *pVect2 = (float *) pVect2v; 119 | size_t qty = *((size_t *) qty_ptr); 120 | 121 | 122 | // size_t qty4 = qty >> 2; 123 | size_t qty16 = qty >> 2; 124 | 125 | const float *pEnd1 = pVect1 + (qty16 << 2); 126 | 127 | __m128 diff, v1, v2; 128 | __m128 sum = _mm_set1_ps(0); 129 | 130 | while (pVect1 < pEnd1) { 131 | v1 = _mm_loadu_ps(pVect1); 132 | pVect1 += 4; 133 | v2 = _mm_loadu_ps(pVect2); 134 | pVect2 += 4; 135 | diff = _mm_sub_ps(v1, v2); 136 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 137 | } 138 | _mm_store_ps(TmpRes, sum); 139 | float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 140 | 141 | return (res); 142 | } 143 | #endif 144 | 145 | class L2Space : public SpaceInterface { 146 | 147 | DISTFUNC fstdistfunc_; 148 | size_t data_size_; 149 | size_t dim_; 150 | public: 151 | L2Space(size_t dim) { 152 | fstdistfunc_ = L2Sqr; 153 | #if defined(USE_SSE) || defined(USE_AVX) 154 | if (dim % 4 == 0) 155 | fstdistfunc_ = L2SqrSIMD4Ext; 156 | if (dim % 16 == 0) 157 | fstdistfunc_ = L2SqrSIMD16Ext; 158 | /*else{ 159 | throw runtime_error("Data type not supported!"); 160 | }*/ 161 | #endif 162 | dim_ = dim; 163 | data_size_ =
dim * sizeof(float); 164 | } 165 | 166 | size_t get_data_size() { 167 | return data_size_; 168 | } 169 | 170 | DISTFUNC get_dist_func() { 171 | return fstdistfunc_; 172 | } 173 | 174 | void *get_dist_func_param() { 175 | return &dim_; 176 | } 177 | 178 | ~L2Space() {} 179 | }; 180 | 181 | static int 182 | L2SqrI(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { 183 | 184 | size_t qty = *((size_t *) qty_ptr); 185 | int res = 0; 186 | unsigned char *a = (unsigned char *) pVect1; 187 | unsigned char *b = (unsigned char *) pVect2; 188 | /*for (int i = 0; i < qty; i++) { 189 | int t = int((a)[i]) - int((b)[i]); 190 | res += t*t; 191 | }*/ 192 | 193 | size_t tail = qty & 3; qty = qty >> 2; 194 | for (size_t i = 0; i < qty; i++) { 195 | 196 | res += ((*a) - (*b)) * ((*a) - (*b)); 197 | a++; 198 | b++; 199 | res += ((*a) - (*b)) * ((*a) - (*b)); 200 | a++; 201 | b++; 202 | res += ((*a) - (*b)) * ((*a) - (*b)); 203 | a++; 204 | b++; 205 | res += ((*a) - (*b)) * ((*a) - (*b)); 206 | a++; 207 | b++; 208 | 209 | 210 | } 211 | for (size_t i = 0; i < tail; i++, a++, b++) res += ((*a) - (*b)) * ((*a) - (*b)); // leftover bytes when dim is not a multiple of 4 212 | return (res); 213 | 214 | } 215 | 216 | class L2SpaceI : public SpaceInterface { 217 | 218 | DISTFUNC fstdistfunc_; 219 | size_t data_size_; 220 | size_t dim_; 221 | public: 222 | L2SpaceI(size_t dim) { 223 | fstdistfunc_ = L2SqrI; 224 | dim_ = dim; 225 | data_size_ = dim * sizeof(unsigned char); 226 | } 227 | 228 | size_t get_data_size() { 229 | return data_size_; 230 | } 231 | 232 | DISTFUNC get_dist_func() { 233 | return fstdistfunc_; 234 | } 235 | 236 | void *get_dist_func_param() { 237 | return &dim_; 238 | } 239 | 240 | ~L2SpaceI() {} 241 | }; 242 | 243 | 244 | } 245 | -------------------------------------------------------------------------------- /hnswlib/space_ip.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "hnswlib.h" 3 | 4 | namespace hnswlib { 5 | 6 | static float 7 | InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { 8 |
size_t qty = *((size_t *) qty_ptr); 9 | float res = 0; 10 | for (unsigned i = 0; i < qty; i++) { 11 | res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; 12 | } 13 | return (1.0f - res); 14 | 15 | } 16 | 17 | #if defined(USE_AVX) 18 | 19 | // Favor using AVX if available. 20 | static float 21 | InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 22 | float PORTABLE_ALIGN32 TmpRes[8]; 23 | float *pVect1 = (float *) pVect1v; 24 | float *pVect2 = (float *) pVect2v; 25 | size_t qty = *((size_t *) qty_ptr); 26 | 27 | size_t qty16 = qty / 16; 28 | size_t qty4 = qty / 4; 29 | 30 | const float *pEnd1 = pVect1 + 16 * qty16; 31 | const float *pEnd2 = pVect1 + 4 * qty4; 32 | 33 | __m256 sum256 = _mm256_set1_ps(0); 34 | 35 | while (pVect1 < pEnd1) { 36 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 37 | 38 | __m256 v1 = _mm256_loadu_ps(pVect1); 39 | pVect1 += 8; 40 | __m256 v2 = _mm256_loadu_ps(pVect2); 41 | pVect2 += 8; 42 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 43 | 44 | v1 = _mm256_loadu_ps(pVect1); 45 | pVect1 += 8; 46 | v2 = _mm256_loadu_ps(pVect2); 47 | pVect2 += 8; 48 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 49 | } 50 | 51 | __m128 v1, v2; 52 | __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); 53 | 54 | while (pVect1 < pEnd2) { 55 | v1 = _mm_loadu_ps(pVect1); 56 | pVect1 += 4; 57 | v2 = _mm_loadu_ps(pVect2); 58 | pVect2 += 4; 59 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 60 | } 61 | 62 | _mm_store_ps(TmpRes, sum_prod); 63 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 64 | return 1.0f - sum; 65 | } 66 | 67 | #elif defined(USE_SSE) 68 | 69 | static float 70 | InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 71 | float PORTABLE_ALIGN32 TmpRes[8]; 72 | float *pVect1 = (float *) pVect1v; 73 | float *pVect2 = (float *) pVect2v; 74 | size_t qty = *((size_t *) qty_ptr); 75 | 76 | size_t qty16
= qty / 16; 77 | size_t qty4 = qty / 4; 78 | 79 | const float *pEnd1 = pVect1 + 16 * qty16; 80 | const float *pEnd2 = pVect1 + 4 * qty4; 81 | 82 | __m128 v1, v2; 83 | __m128 sum_prod = _mm_set1_ps(0); 84 | 85 | while (pVect1 < pEnd1) { 86 | v1 = _mm_loadu_ps(pVect1); 87 | pVect1 += 4; 88 | v2 = _mm_loadu_ps(pVect2); 89 | pVect2 += 4; 90 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 91 | 92 | v1 = _mm_loadu_ps(pVect1); 93 | pVect1 += 4; 94 | v2 = _mm_loadu_ps(pVect2); 95 | pVect2 += 4; 96 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 97 | 98 | v1 = _mm_loadu_ps(pVect1); 99 | pVect1 += 4; 100 | v2 = _mm_loadu_ps(pVect2); 101 | pVect2 += 4; 102 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 103 | 104 | v1 = _mm_loadu_ps(pVect1); 105 | pVect1 += 4; 106 | v2 = _mm_loadu_ps(pVect2); 107 | pVect2 += 4; 108 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 109 | } 110 | 111 | while (pVect1 < pEnd2) { 112 | v1 = _mm_loadu_ps(pVect1); 113 | pVect1 += 4; 114 | v2 = _mm_loadu_ps(pVect2); 115 | pVect2 += 4; 116 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 117 | } 118 | 119 | _mm_store_ps(TmpRes, sum_prod); 120 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 121 | 122 | return 1.0f - sum; 123 | } 124 | 125 | #endif 126 | 127 | #if defined(USE_AVX) 128 | 129 | static float 130 | InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 131 | float PORTABLE_ALIGN32 TmpRes[8]; 132 | float *pVect1 = (float *) pVect1v; 133 | float *pVect2 = (float *) pVect2v; 134 | size_t qty = *((size_t *) qty_ptr); 135 | 136 | size_t qty16 = qty / 16; 137 | 138 | 139 | const float *pEnd1 = pVect1 + 16 * qty16; 140 | 141 | __m256 sum256 = _mm256_set1_ps(0); 142 | 143 | while (pVect1 < pEnd1) { 144 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 145 | 146 | __m256 v1 = _mm256_loadu_ps(pVect1); 147 | pVect1 += 8; 148 | __m256 v2 = _mm256_loadu_ps(pVect2); 149 | pVect2 += 8; 150 | sum256 = _mm256_add_ps(sum256,
_mm256_mul_ps(v1, v2)); 151 | 152 | v1 = _mm256_loadu_ps(pVect1); 153 | pVect1 += 8; 154 | v2 = _mm256_loadu_ps(pVect2); 155 | pVect2 += 8; 156 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 157 | } 158 | 159 | _mm256_store_ps(TmpRes, sum256); 160 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; 161 | 162 | return 1.0f - sum; 163 | } 164 | 165 | #elif defined(USE_SSE) 166 | 167 | static float 168 | InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 169 | float PORTABLE_ALIGN32 TmpRes[8]; 170 | float *pVect1 = (float *) pVect1v; 171 | float *pVect2 = (float *) pVect2v; 172 | size_t qty = *((size_t *) qty_ptr); 173 | 174 | size_t qty16 = qty / 16; 175 | 176 | const float *pEnd1 = pVect1 + 16 * qty16; 177 | 178 | __m128 v1, v2; 179 | __m128 sum_prod = _mm_set1_ps(0); 180 | 181 | while (pVect1 < pEnd1) { 182 | v1 = _mm_loadu_ps(pVect1); 183 | pVect1 += 4; 184 | v2 = _mm_loadu_ps(pVect2); 185 | pVect2 += 4; 186 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 187 | 188 | v1 = _mm_loadu_ps(pVect1); 189 | pVect1 += 4; 190 | v2 = _mm_loadu_ps(pVect2); 191 | pVect2 += 4; 192 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 193 | 194 | v1 = _mm_loadu_ps(pVect1); 195 | pVect1 += 4; 196 | v2 = _mm_loadu_ps(pVect2); 197 | pVect2 += 4; 198 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 199 | 200 | v1 = _mm_loadu_ps(pVect1); 201 | pVect1 += 4; 202 | v2 = _mm_loadu_ps(pVect2); 203 | pVect2 += 4; 204 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 205 | } 206 | _mm_store_ps(TmpRes, sum_prod); 207 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 208 | 209 | return 1.0f - sum; 210 | } 211 | 212 | #endif 213 | 214 | class InnerProductSpace : public SpaceInterface { 215 | 216 | DISTFUNC fstdistfunc_; 217 | size_t data_size_; 218 | size_t dim_; 219 | public: 220 | InnerProductSpace(size_t dim) { 221 | fstdistfunc_ = InnerProduct; 222 | #if
defined(USE_AVX) || defined(USE_SSE) 223 | if (dim % 4 == 0) 224 | fstdistfunc_ = InnerProductSIMD4Ext; 225 | if (dim % 16 == 0) 226 | fstdistfunc_ = InnerProductSIMD16Ext; 227 | #endif 228 | dim_ = dim; 229 | data_size_ = dim * sizeof(float); 230 | } 231 | 232 | size_t get_data_size() { 233 | return data_size_; 234 | } 235 | 236 | DISTFUNC get_dist_func() { 237 | return fstdistfunc_; 238 | } 239 | 240 | void *get_dist_func_param() { 241 | return &dim_; 242 | } 243 | 244 | ~InnerProductSpace() {} 245 | }; 246 | 247 | 248 | } 249 | -------------------------------------------------------------------------------- /hnswlib/hnswalg.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "visited_list_pool.h" 4 | #include "hnswlib.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | 11 | namespace hnswlib { 12 | typedef unsigned int tableint; 13 | typedef unsigned int linklistsizeint; 14 | 15 | template 16 | class HierarchicalNSW : public AlgorithmInterface { 17 | public: 18 | 19 | HierarchicalNSW(SpaceInterface *s) { 20 | 21 | } 22 | 23 | HierarchicalNSW(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { 24 | loadIndex(location, s, max_elements); 25 | } 26 | 27 | HierarchicalNSW(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : 28 | link_list_locks_(max_elements), element_levels_(max_elements) { 29 | max_elements_ = max_elements; 30 | 31 | has_deletions_=false; 32 | data_size_ = s->get_data_size(); 33 | fstdistfunc_ = s->get_dist_func(); 34 | dist_func_param_ = s->get_dist_func_param(); 35 | M_ = M; 36 | maxM_ = M_; 37 | maxM0_ = M_ * 2; 38 | ef_construction_ = std::max(ef_construction,M_); 39 | ef_ = 10; 40 | 41 | level_generator_.seed(random_seed); 42 | 43 | size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); 44 | size_data_per_element_ = size_links_level0_ + data_size_
+ sizeof(labeltype); 45 | offsetData_ = size_links_level0_; 46 | label_offset_ = size_links_level0_ + data_size_; 47 | offsetLevel0_ = 0; 48 | 49 | data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); 50 | if (data_level0_memory_ == nullptr) 51 | throw std::runtime_error("Not enough memory"); 52 | 53 | cur_element_count = 0; 54 | 55 | visited_list_pool_ = new VisitedListPool(1, max_elements); 56 | 57 | 58 | 59 | //initializations for special treatment of the first node 60 | enterpoint_node_ = -1; 61 | maxlevel_ = -1; 62 | 63 | linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); 64 | if (linkLists_ == nullptr) 65 | throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); 66 | size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); 67 | mult_ = 1 / log(1.0 * M_); 68 | revSize_ = 1.0 / mult_; 69 | } 70 | 71 | struct CompareByFirst { 72 | constexpr bool operator()(std::pair const &a, 73 | std::pair const &b) const noexcept { 74 | return a.first < b.first; 75 | } 76 | }; 77 | 78 | ~HierarchicalNSW() { 79 | 80 | free(data_level0_memory_); 81 | for (tableint i = 0; i < cur_element_count; i++) { 82 | if (element_levels_[i] > 0) 83 | free(linkLists_[i]); 84 | } 85 | free(linkLists_); 86 | delete visited_list_pool_; 87 | } 88 | 89 | size_t max_elements_; 90 | size_t cur_element_count; 91 | size_t size_data_per_element_; 92 | size_t size_links_per_element_; 93 | 94 | size_t M_; 95 | size_t maxM_; 96 | size_t maxM0_; 97 | size_t ef_construction_; 98 | 99 | double mult_, revSize_; 100 | int maxlevel_; 101 | 102 | 103 | VisitedListPool *visited_list_pool_; 104 | std::mutex cur_element_count_guard_; 105 | 106 | std::vector link_list_locks_; 107 | tableint enterpoint_node_; 108 | 109 | 110 | size_t size_links_level0_; 111 | size_t offsetData_, offsetLevel0_; 112 | 113 | 114 | char *data_level0_memory_; 115 | char **linkLists_; 116 | std::vector element_levels_; 117 | 118 | size_t 
data_size_; 119 | 120 | bool has_deletions_; 121 | 122 | 123 | size_t label_offset_; 124 | DISTFUNC fstdistfunc_; 125 | void *dist_func_param_; 126 | std::unordered_map label_lookup_; 127 | 128 | std::default_random_engine level_generator_; 129 | 130 | inline labeltype getExternalLabel(tableint internal_id) const { 131 | labeltype return_label; 132 | memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); 133 | return return_label; 134 | } 135 | 136 | inline void setExternalLabel(tableint internal_id, labeltype label) const { 137 | memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); 138 | } 139 | 140 | inline labeltype *getExternalLabeLp(tableint internal_id) const { 141 | return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); 142 | } 143 | 144 | inline char *getDataByInternalId(tableint internal_id) const { 145 | return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); 146 | } 147 | 148 | int getRandomLevel(double reverse_size) { 149 | std::uniform_real_distribution distribution(0.0, 1.0); 150 | double r = -log(distribution(level_generator_)) * reverse_size; 151 | return (int) r; 152 | } 153 | 154 | std::priority_queue, std::vector>, CompareByFirst> 155 | searchBaseLayer(tableint ep_id, const void *data_point, int layer) { 156 | VisitedList *vl = visited_list_pool_->getFreeVisitedList(); 157 | vl_type *visited_array = vl->mass; 158 | vl_type visited_array_tag = vl->curV; 159 | 160 | std::priority_queue, std::vector>, CompareByFirst> top_candidates; 161 | std::priority_queue, std::vector>, CompareByFirst> candidateSet; 162 | 163 | dist_t lowerBound; 164 | if (!isMarkedDeleted(ep_id)) { 165 | dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); 166 | top_candidates.emplace(dist, ep_id); 167 | lowerBound = dist; 168 | 
candidateSet.emplace(-dist, ep_id); 169 | } else { 170 | lowerBound = std::numeric_limits::max(); 171 | candidateSet.emplace(-lowerBound, ep_id); 172 | } 173 | visited_array[ep_id] = visited_array_tag; 174 | 175 | while (!candidateSet.empty()) { 176 | std::pair curr_el_pair = candidateSet.top(); 177 | if ((-curr_el_pair.first) > lowerBound) { 178 | break; 179 | } 180 | candidateSet.pop(); 181 | 182 | tableint curNodeNum = curr_el_pair.second; 183 | 184 | std::unique_lock lock(link_list_locks_[curNodeNum]); 185 | 186 | int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); 187 | if (layer == 0) { 188 | data = (int*)get_linklist0(curNodeNum); 189 | } else { 190 | data = (int*)get_linklist(curNodeNum, layer); 191 | // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); 192 | } 193 | size_t size = getListCount((linklistsizeint*)data); 194 | tableint *datal = (tableint *) (data + 1); 195 | #ifdef USE_SSE 196 | _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); 197 | _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); 198 | _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); 199 | _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); 200 | #endif 201 | 202 | for (size_t j = 0; j < size; j++) { 203 | tableint candidate_id = *(datal + j); 204 | // if (candidate_id == 0) continue; 205 | #ifdef USE_SSE 206 | _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); 207 | _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); 208 | #endif 209 | if (visited_array[candidate_id] == visited_array_tag) continue; 210 | visited_array[candidate_id] = visited_array_tag; 211 | char *currObj1 = (getDataByInternalId(candidate_id)); 212 | 213 | dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); 214 | if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { 215 | candidateSet.emplace(-dist1, candidate_id); 216 | #ifdef USE_SSE 217 | 
_mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); 218 | #endif 219 | 220 | if (!isMarkedDeleted(candidate_id)) 221 | top_candidates.emplace(dist1, candidate_id); 222 | 223 | if (top_candidates.size() > ef_construction_) 224 | top_candidates.pop(); 225 | 226 | if (!top_candidates.empty()) 227 | lowerBound = top_candidates.top().first; 228 | } 229 | } 230 | } 231 | visited_list_pool_->releaseVisitedList(vl); 232 | 233 | return top_candidates; 234 | } 235 | 236 | template 237 | std::priority_queue, std::vector>, CompareByFirst> 238 | searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { 239 | VisitedList *vl = visited_list_pool_->getFreeVisitedList(); 240 | vl_type *visited_array = vl->mass; 241 | vl_type visited_array_tag = vl->curV; 242 | 243 | std::priority_queue, std::vector>, CompareByFirst> top_candidates; 244 | std::priority_queue, std::vector>, CompareByFirst> candidate_set; 245 | 246 | dist_t lowerBound; 247 | if (!has_deletions || !isMarkedDeleted(ep_id)) { 248 | dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); 249 | lowerBound = dist; 250 | top_candidates.emplace(dist, ep_id); 251 | candidate_set.emplace(-dist, ep_id); 252 | } else { 253 | lowerBound = std::numeric_limits::max(); 254 | candidate_set.emplace(-lowerBound, ep_id); 255 | } 256 | 257 | visited_array[ep_id] = visited_array_tag; 258 | 259 | while (!candidate_set.empty()) { 260 | 261 | std::pair current_node_pair = candidate_set.top(); 262 | 263 | if ((-current_node_pair.first) > lowerBound) { 264 | break; 265 | } 266 | candidate_set.pop(); 267 | 268 | tableint current_node_id = current_node_pair.second; 269 | int *data = (int *) get_linklist0(current_node_id); 270 | size_t size = getListCount((linklistsizeint*)data); 271 | // bool cur_node_deleted = isMarkedDeleted(current_node_id); 272 | 273 | #ifdef USE_SSE 274 | _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); 275 | _mm_prefetch((char *) 
(visited_array + *(data + 1) + 64), _MM_HINT_T0); 276 | _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); 277 | _mm_prefetch((char *) (data + 2), _MM_HINT_T0); 278 | #endif 279 | 280 | for (size_t j = 1; j <= size; j++) { 281 | int candidate_id = *(data + j); 282 | // if (candidate_id == 0) continue; 283 | #ifdef USE_SSE 284 | _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); 285 | _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, 286 | _MM_HINT_T0);//////////// 287 | #endif 288 | if (!(visited_array[candidate_id] == visited_array_tag)) { 289 | 290 | visited_array[candidate_id] = visited_array_tag; 291 | 292 | char *currObj1 = (getDataByInternalId(candidate_id)); 293 | dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); 294 | 295 | if (top_candidates.size() < ef || lowerBound > dist) { 296 | candidate_set.emplace(-dist, candidate_id); 297 | #ifdef USE_SSE 298 | _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + 299 | offsetLevel0_,/////////// 300 | _MM_HINT_T0);//////////////////////// 301 | #endif 302 | 303 | if (!has_deletions || !isMarkedDeleted(candidate_id)) 304 | top_candidates.emplace(dist, candidate_id); 305 | 306 | if (top_candidates.size() > ef) 307 | top_candidates.pop(); 308 | 309 | if (!top_candidates.empty()) 310 | lowerBound = top_candidates.top().first; 311 | } 312 | } 313 | } 314 | } 315 | 316 | visited_list_pool_->releaseVisitedList(vl); 317 | return top_candidates; 318 | } 319 | 320 | void getNeighborsByHeuristic2( 321 | std::priority_queue, std::vector>, CompareByFirst> &top_candidates, 322 | const size_t M) { 323 | if (top_candidates.size() < M) { 324 | return; 325 | } 326 | std::priority_queue> queue_closest; 327 | std::vector> return_list; 328 | while (top_candidates.size() > 0) { 329 | queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); 330 | 
top_candidates.pop(); 331 | } 332 | 333 | while (queue_closest.size()) { 334 | if (return_list.size() >= M) 335 | break; 336 | std::pair curent_pair = queue_closest.top(); 337 | dist_t dist_to_query = -curent_pair.first; 338 | queue_closest.pop(); 339 | bool good = true; 340 | for (std::pair second_pair : return_list) { 341 | dist_t curdist = 342 | fstdistfunc_(getDataByInternalId(second_pair.second), 343 | getDataByInternalId(curent_pair.second), 344 | dist_func_param_);; 345 | if (curdist < dist_to_query) { 346 | good = false; 347 | break; 348 | } 349 | } 350 | if (good) { 351 | return_list.push_back(curent_pair); 352 | } 353 | 354 | 355 | } 356 | 357 | for (std::pair curent_pair : return_list) { 358 | 359 | top_candidates.emplace(-curent_pair.first, curent_pair.second); 360 | } 361 | } 362 | 363 | 364 | linklistsizeint *get_linklist0(tableint internal_id) const { 365 | return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); 366 | }; 367 | 368 | linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { 369 | return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); 370 | }; 371 | 372 | linklistsizeint *get_linklist(tableint internal_id, int level) const { 373 | return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); 374 | }; 375 | 376 | void mutuallyConnectNewElement(const void *data_point, tableint cur_c, 377 | std::priority_queue, std::vector>, CompareByFirst> top_candidates, 378 | int level) { 379 | 380 | size_t Mcurmax = level ? 
maxM_ : maxM0_; 381 | getNeighborsByHeuristic2(top_candidates, M_); 382 | if (top_candidates.size() > M_) 383 | throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); 384 | 385 | std::vector selectedNeighbors; 386 | selectedNeighbors.reserve(M_); 387 | while (top_candidates.size() > 0) { 388 | selectedNeighbors.push_back(top_candidates.top().second); 389 | top_candidates.pop(); 390 | } 391 | 392 | { 393 | linklistsizeint *ll_cur; 394 | if (level == 0) 395 | ll_cur = get_linklist0(cur_c); 396 | else 397 | ll_cur = get_linklist(cur_c, level); 398 | 399 | if (*ll_cur) { 400 | throw std::runtime_error("The newly inserted element should have blank link list"); 401 | } 402 | setListCount(ll_cur,selectedNeighbors.size()); 403 | tableint *data = (tableint *) (ll_cur + 1); 404 | 405 | 406 | for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { 407 | if (data[idx]) 408 | throw std::runtime_error("Possible memory corruption"); 409 | if (level > element_levels_[selectedNeighbors[idx]]) 410 | throw std::runtime_error("Trying to make a link on a non-existent level"); 411 | 412 | data[idx] = selectedNeighbors[idx]; 413 | 414 | } 415 | } 416 | for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { 417 | 418 | std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); 419 | 420 | 421 | linklistsizeint *ll_other; 422 | if (level == 0) 423 | ll_other = get_linklist0(selectedNeighbors[idx]); 424 | else 425 | ll_other = get_linklist(selectedNeighbors[idx], level); 426 | 427 | size_t sz_link_list_other = getListCount(ll_other); 428 | 429 | if (sz_link_list_other > Mcurmax) 430 | throw std::runtime_error("Bad value of sz_link_list_other"); 431 | if (selectedNeighbors[idx] == cur_c) 432 | throw std::runtime_error("Trying to connect an element to itself"); 433 | if (level > element_levels_[selectedNeighbors[idx]]) 434 | throw std::runtime_error("Trying to make a link on a non-existent level"); 435 | 436 | tableint *data = 
(tableint *) (ll_other + 1); 437 | if (sz_link_list_other < Mcurmax) { 438 | data[sz_link_list_other] = cur_c; 439 | setListCount(ll_other, sz_link_list_other + 1); 440 | } else { 441 | // finding the "weakest" element to replace it with the new one 442 | dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), 443 | dist_func_param_); 444 | // Heuristic: 445 | std::priority_queue, std::vector>, CompareByFirst> candidates; 446 | candidates.emplace(d_max, cur_c); 447 | 448 | for (size_t j = 0; j < sz_link_list_other; j++) { 449 | candidates.emplace( 450 | fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), 451 | dist_func_param_), data[j]); 452 | } 453 | 454 | getNeighborsByHeuristic2(candidates, Mcurmax); 455 | 456 | int indx = 0; 457 | while (candidates.size() > 0) { 458 | data[indx] = candidates.top().second; 459 | candidates.pop(); 460 | indx++; 461 | } 462 | setListCount(ll_other, indx); 463 | // Nearest K: 464 | /*int indx = -1; 465 | for (int j = 0; j < sz_link_list_other; j++) { 466 | dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); 467 | if (d > d_max) { 468 | indx = j; 469 | d_max = d; 470 | } 471 | } 472 | if (indx >= 0) { 473 | data[indx] = cur_c; 474 | } */ 475 | } 476 | 477 | } 478 | } 479 | 480 | std::mutex global; 481 | size_t ef_; 482 | 483 | void setEf(size_t ef) { 484 | ef_ = ef; 485 | } 486 | 487 | 488 | std::priority_queue> searchKnnInternal(void *query_data, int k) { 489 | std::priority_queue> top_candidates; 490 | if (cur_element_count == 0) return top_candidates; 491 | tableint currObj = enterpoint_node_; 492 | dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); 493 | 494 | for (size_t level = maxlevel_; level > 0; level--) { 495 | bool changed = true; 496 | while (changed) { 497 | changed = false; 498 | int *data; 499 | data = (int *) 
get_linklist(currObj,level); 500 | int size = getListCount(data); 501 | tableint *datal = (tableint *) (data + 1); 502 | for (int i = 0; i < size; i++) { 503 | tableint cand = datal[i]; 504 | if (cand < 0 || cand > max_elements_) 505 | throw std::runtime_error("cand error"); 506 | dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); 507 | 508 | if (d < curdist) { 509 | curdist = d; 510 | currObj = cand; 511 | changed = true; 512 | } 513 | } 514 | } 515 | } 516 | 517 | if (has_deletions_) { 518 | std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, 519 | ef_); 520 | top_candidates.swap(top_candidates1); 521 | } 522 | else{ 523 | std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, 524 | ef_); 525 | top_candidates.swap(top_candidates1); 526 | } 527 | 528 | while (top_candidates.size() > k) { 529 | top_candidates.pop(); 530 | } 531 | return top_candidates; 532 | }; 533 | 534 | void resizeIndex(size_t new_max_elements){ 535 | if (new_max_elements(new_max_elements).swap(link_list_locks_); 547 | 548 | 549 | // Reallocate base layer 550 | char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_); 551 | if (data_level0_memory_new == nullptr) 552 | throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); 553 | memcpy(data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_); 554 | free(data_level0_memory_); 555 | data_level0_memory_=data_level0_memory_new; 556 | 557 | // Reallocate all other layers 558 | char ** linkLists_new = (char **) malloc(sizeof(void *) * new_max_elements); 559 | if (linkLists_new == nullptr) 560 | throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); 561 | memcpy(linkLists_new, linkLists_,cur_element_count * sizeof(void *)); 562 | free(linkLists_); 563 | linkLists_=linkLists_new; 564 | 565 | max_elements_=new_max_elements; 566 | 567 | } 568 | 
569 | void saveIndex(const std::string &location) { 570 | std::ofstream output(location, std::ios::binary); 571 | std::streampos position; 572 | 573 | writeBinaryPOD(output, offsetLevel0_); 574 | writeBinaryPOD(output, max_elements_); 575 | writeBinaryPOD(output, cur_element_count); 576 | writeBinaryPOD(output, size_data_per_element_); 577 | writeBinaryPOD(output, label_offset_); 578 | writeBinaryPOD(output, offsetData_); 579 | writeBinaryPOD(output, maxlevel_); 580 | writeBinaryPOD(output, enterpoint_node_); 581 | writeBinaryPOD(output, maxM_); 582 | 583 | writeBinaryPOD(output, maxM0_); 584 | writeBinaryPOD(output, M_); 585 | writeBinaryPOD(output, mult_); 586 | writeBinaryPOD(output, ef_construction_); 587 | 588 | output.write(data_level0_memory_, cur_element_count * size_data_per_element_); 589 | 590 | for (size_t i = 0; i < cur_element_count; i++) { 591 | unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; 592 | writeBinaryPOD(output, linkListSize); 593 | if (linkListSize) 594 | output.write(linkLists_[i], linkListSize); 595 | } 596 | output.close(); 597 | } 598 | 599 | void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i=0) { 600 | 601 | 602 | std::ifstream input(location, std::ios::binary); 603 | 604 | if (!input.is_open()) 605 | throw std::runtime_error("Cannot open file"); 606 | 607 | 608 | // get file size: 609 | input.seekg(0,input.end); 610 | std::streampos total_filesize=input.tellg(); 611 | input.seekg(0,input.beg); 612 | 613 | readBinaryPOD(input, offsetLevel0_); 614 | readBinaryPOD(input, max_elements_); 615 | readBinaryPOD(input, cur_element_count); 616 | 617 | size_t max_elements=max_elements_i; 618 | if(max_elements < cur_element_count) 619 | max_elements = max_elements_; 620 | max_elements_ = max_elements; 621 | readBinaryPOD(input, size_data_per_element_); 622 | readBinaryPOD(input, label_offset_); 623 | readBinaryPOD(input, offsetData_); 624 | 
readBinaryPOD(input, maxlevel_); 625 | readBinaryPOD(input, enterpoint_node_); 626 | 627 | readBinaryPOD(input, maxM_); 628 | readBinaryPOD(input, maxM0_); 629 | readBinaryPOD(input, M_); 630 | readBinaryPOD(input, mult_); 631 | readBinaryPOD(input, ef_construction_); 632 | 633 | 634 | data_size_ = s->get_data_size(); 635 | fstdistfunc_ = s->get_dist_func(); 636 | dist_func_param_ = s->get_dist_func_param(); 637 | 638 | auto pos=input.tellg(); 639 | 640 | 641 | /// Optional - check if index is ok: 642 | 643 | input.seekg(cur_element_count * size_data_per_element_,input.cur); 644 | for (size_t i = 0; i < cur_element_count; i++) { 645 | if(input.tellg() < 0 || input.tellg()>=total_filesize){ 646 | throw std::runtime_error("Index seems to be corrupted or unsupported"); 647 | } 648 | 649 | unsigned int linkListSize; 650 | readBinaryPOD(input, linkListSize); 651 | if (linkListSize != 0) { 652 | input.seekg(linkListSize,input.cur); 653 | } 654 | } 655 | 656 | // throw exception if it either corrupted or old index 657 | if(input.tellg()!=total_filesize) 658 | throw std::runtime_error("Index seems to be corrupted or unsupported"); 659 | 660 | input.clear(); 661 | 662 | /// Optional check end 663 | 664 | input.seekg(pos,input.beg); 665 | 666 | 667 | data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); 668 | if (data_level0_memory_ == nullptr) 669 | throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); 670 | input.read(data_level0_memory_, cur_element_count * size_data_per_element_); 671 | 672 | 673 | 674 | 675 | size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); 676 | 677 | 678 | size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); 679 | std::vector(max_elements).swap(link_list_locks_); 680 | 681 | 682 | visited_list_pool_ = new VisitedListPool(1, max_elements); 683 | 684 | 685 | linkLists_ = (char **) malloc(sizeof(void *) * max_elements); 686 | if (linkLists_ == nullptr) 
687 | throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); 688 | element_levels_ = std::vector(max_elements); 689 | revSize_ = 1.0 / mult_; 690 | ef_ = 10; 691 | for (size_t i = 0; i < cur_element_count; i++) { 692 | label_lookup_[getExternalLabel(i)]=i; 693 | unsigned int linkListSize; 694 | readBinaryPOD(input, linkListSize); 695 | if (linkListSize == 0) { 696 | element_levels_[i] = 0; 697 | 698 | linkLists_[i] = nullptr; 699 | } else { 700 | element_levels_[i] = linkListSize / size_links_per_element_; 701 | linkLists_[i] = (char *) malloc(linkListSize); 702 | if (linkLists_[i] == nullptr) 703 | throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); 704 | input.read(linkLists_[i], linkListSize); 705 | } 706 | } 707 | 708 | has_deletions_=false; 709 | 710 | for (size_t i = 0; i < cur_element_count; i++) { 711 | if(isMarkedDeleted(i)) 712 | has_deletions_=true; 713 | } 714 | 715 | input.close(); 716 | 717 | return; 718 | } 719 | 720 | template 721 | std::vector getDataByLabel(labeltype label) 722 | { 723 | tableint label_c; 724 | auto search = label_lookup_.find(label); 725 | if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { 726 | throw std::runtime_error("Label not found"); 727 | } 728 | label_c = search->second; 729 | 730 | char* data_ptrv = getDataByInternalId(label_c); 731 | size_t dim = *((size_t *) dist_func_param_); 732 | std::vector data; 733 | data_t* data_ptr = (data_t*) data_ptrv; 734 | for (int i = 0; i < dim; i++) { 735 | data.push_back(*data_ptr); 736 | data_ptr += 1; 737 | } 738 | return data; 739 | } 740 | 741 | static const unsigned char DELETE_MARK = 0x01; 742 | // static const unsigned char REUSE_MARK = 0x10; 743 | /** 744 | * Marks an element with the given label deleted, does NOT really change the current graph. 
745 | * @param label 746 | */ 747 | void markDelete(labeltype label) 748 | { 749 | has_deletions_=true; 750 | auto search = label_lookup_.find(label); 751 | if (search == label_lookup_.end()) { 752 | throw std::runtime_error("Label not found"); 753 | } 754 | markDeletedInternal(search->second); 755 | } 756 | 757 | /** 758 | * Uses the first 8 bits of the memory for the linked list to store the mark, 759 | * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases. 760 | * @param internalId 761 | */ 762 | void markDeletedInternal(tableint internalId) { 763 | unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; 764 | *ll_cur |= DELETE_MARK; 765 | } 766 | 767 | /** 768 | * Remove the deleted mark of the node. 769 | * @param internalId 770 | */ 771 | void unmarkDeletedInternal(tableint internalId) { 772 | unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; 773 | *ll_cur &= ~DELETE_MARK; 774 | } 775 | 776 | /** 777 | * Checks the first 8 bits of the memory to see if the element is marked deleted. 
778 | * @param internalId 779 | * @return 780 | */ 781 | bool isMarkedDeleted(tableint internalId) const { 782 | unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2; 783 | return *ll_cur & DELETE_MARK; 784 | } 785 | 786 | unsigned short int getListCount(linklistsizeint * ptr) const { 787 | return *((unsigned short int *)ptr); 788 | } 789 | 790 | void setListCount(linklistsizeint * ptr, unsigned short int size) const { 791 | *((unsigned short int*)(ptr))=*((unsigned short int *)&size); 792 | } 793 | 794 | void addPoint(const void *data_point, labeltype label) { 795 | addPoint(data_point, label,-1); 796 | } 797 | 798 | tableint addPoint(const void *data_point, labeltype label, int level) { 799 | tableint cur_c = 0; 800 | { 801 | std::unique_lock lock(cur_element_count_guard_); 802 | if (cur_element_count >= max_elements_) { 803 | throw std::runtime_error("The number of elements exceeds the specified limit"); 804 | }; 805 | 806 | cur_c = cur_element_count; 807 | cur_element_count++; 808 | 809 | auto search = label_lookup_.find(label); 810 | if (search != label_lookup_.end()) { 811 | std::unique_lock lock_el(link_list_locks_[search->second]); 812 | has_deletions_ = true; 813 | markDeletedInternal(search->second); 814 | } 815 | label_lookup_[label] = cur_c; 816 | } 817 | 818 | std::unique_lock lock_el(link_list_locks_[cur_c]); 819 | int curlevel = getRandomLevel(mult_); 820 | if (level > 0) 821 | curlevel = level; 822 | 823 | element_levels_[cur_c] = curlevel; 824 | 825 | 826 | std::unique_lock templock(global); 827 | int maxlevelcopy = maxlevel_; 828 | if (curlevel <= maxlevelcopy) 829 | templock.unlock(); 830 | tableint currObj = enterpoint_node_; 831 | tableint enterpoint_copy = enterpoint_node_; 832 | 833 | 834 | memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); 835 | 836 | // Initialisation of the data and label 837 | memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); 838 | 
memcpy(getDataByInternalId(cur_c), data_point, data_size_); 839 | 840 | 841 | if (curlevel) { 842 | linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); 843 | if (linkLists_[cur_c] == nullptr) 844 | throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); 845 | memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); 846 | } 847 | 848 | if ((signed)currObj != -1) { 849 | 850 | if (curlevel < maxlevelcopy) { 851 | 852 | dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); 853 | for (int level = maxlevelcopy; level > curlevel; level--) { 854 | 855 | 856 | bool changed = true; 857 | while (changed) { 858 | changed = false; 859 | unsigned int *data; 860 | std::unique_lock lock(link_list_locks_[currObj]); 861 | data = get_linklist(currObj,level); 862 | int size = getListCount(data); 863 | 864 | tableint *datal = (tableint *) (data + 1); 865 | for (int i = 0; i < size; i++) { 866 | tableint cand = datal[i]; 867 | if (cand < 0 || cand > max_elements_) 868 | throw std::runtime_error("cand error"); 869 | dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); 870 | if (d < curdist) { 871 | curdist = d; 872 | currObj = cand; 873 | changed = true; 874 | } 875 | } 876 | } 877 | } 878 | } 879 | 880 | bool epDeleted = isMarkedDeleted(enterpoint_copy); 881 | for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { 882 | if (level > maxlevelcopy || level < 0) // possible? 
883 | throw std::runtime_error("Level error"); 884 | 885 | std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( 886 | currObj, data_point, level); 887 | if (epDeleted) { 888 | top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); 889 | if (top_candidates.size() > ef_construction_) 890 | top_candidates.pop(); 891 | } 892 | mutuallyConnectNewElement(data_point, cur_c, top_candidates, level); 893 | 894 | currObj = top_candidates.top().second; 895 | } 896 | 897 | 898 | } else { 899 | // Do nothing for the first element 900 | enterpoint_node_ = 0; 901 | maxlevel_ = curlevel; 902 | 903 | } 904 | 905 | //Releasing lock for the maximum level 906 | if (curlevel > maxlevelcopy) { 907 | enterpoint_node_ = cur_c; 908 | maxlevel_ = curlevel; 909 | } 910 | return cur_c; 911 | }; 912 | 913 | std::priority_queue> 914 | searchKnn(const void *query_data, size_t k) const { 915 | std::priority_queue> result; 916 | if (cur_element_count == 0) return result; 917 | 918 | tableint currObj = enterpoint_node_; 919 | dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); 920 | 921 | for (int level = maxlevel_; level > 0; level--) { 922 | bool changed = true; 923 | while (changed) { 924 | changed = false; 925 | unsigned int *data; 926 | 927 | data = (unsigned int *) get_linklist(currObj, level); 928 | int size = getListCount(data); 929 | tableint *datal = (tableint *) (data + 1); 930 | for (int i = 0; i < size; i++) { 931 | tableint cand = datal[i]; 932 | if (cand < 0 || cand > max_elements_) 933 | throw std::runtime_error("cand error"); 934 | dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); 935 | 936 | if (d < curdist) { 937 | curdist = d; 938 | currObj = cand; 939 | changed = true; 940 | } 941 | } 942 | } 943 | } 944 | 945 | std::priority_queue, std::vector>, CompareByFirst> top_candidates; 946 | if 
(has_deletions_) { 947 | std::priority_queue, std::vector>, CompareByFirst> top_candidates1=searchBaseLayerST( 948 | currObj, query_data, std::max(ef_, k)); 949 | top_candidates.swap(top_candidates1); 950 | } 951 | else{ 952 | std::priority_queue, std::vector>, CompareByFirst> top_candidates1=searchBaseLayerST( 953 | currObj, query_data, std::max(ef_, k)); 954 | top_candidates.swap(top_candidates1); 955 | } 956 | while (top_candidates.size() > k) { 957 | top_candidates.pop(); 958 | } 959 | while (top_candidates.size() > 0) { 960 | std::pair rez = top_candidates.top(); 961 | result.push(std::pair(rez.first, getExternalLabel(rez.second))); 962 | top_candidates.pop(); 963 | } 964 | return result; 965 | }; 966 | 967 | template 968 | std::vector> 969 | searchKnn(const void* query_data, size_t k, Comp comp) { 970 | std::vector> result; 971 | if (cur_element_count == 0) return result; 972 | 973 | auto ret = searchKnn(query_data, k); 974 | 975 | while (!ret.empty()) { 976 | result.push_back(ret.top()); 977 | ret.pop(); 978 | } 979 | 980 | std::sort(result.begin(), result.end(), comp); 981 | 982 | return result; 983 | } 984 | 985 | }; 986 | 987 | } 988 | --------------------------------------------------------------------------------