├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── benchmark ├── CMakeLists.txt ├── generate_groundtruth.cc ├── serf_arbitrary.cc └── serf_halfbound.cc ├── include ├── base_index.h ├── common │ ├── data_processing.h │ ├── data_wrapper.h │ ├── logger.h │ ├── reader.h │ └── utils.h └── incremental_hnsw │ ├── bruteforce.h │ ├── hnswalg.h │ ├── hnswlib.h │ ├── space_ip.h │ ├── space_l2.h │ └── visited_list_pool.h ├── sample_data ├── deep_10k.fvecs └── deep_query.fvecs └── src ├── base_hnsw ├── bruteforce.h ├── hnswalg.h ├── hnswlib.h ├── space_ip.h ├── space_l2.h └── visited_list_pool.h ├── baselines └── knn_first_hnsw.h ├── common ├── CMakeLists.txt ├── data_processing.cc ├── data_wrapper.cc ├── logger.cc ├── reader.cc └── utils.cc ├── index_base.h ├── range_index_base.h ├── segment_graph_1d.h └── segment_graph_2d.h /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | 3 | build 4 | build/* 5 | 6 | .vscode/* 7 | .vscode -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12) 2 | project(RangeFiltering-ANNS) 3 | include(CheckCXXCompilerFlag) 4 | 5 | find_package(OpenMP) 6 | if (OPENMP_FOUND) 7 | set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 8 | set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 9 | set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") 10 | endif() 11 | 12 | MESSAGE(STATUS ${CMAKE_CXX_FLAGS}) 13 | 14 | set(CMAKE_CXX_STANDARD 11) 15 | 16 | MESSAGE(${CMAKE_SYSTEM}) 17 | 18 | 19 | set(CMAKE_TRY_COMPILE_TARGET_TYPE "STATIC_LIBRARY") 20 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 21 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 ") # append to the C++ flags; copying CMAKE_C_FLAGS here dropped user CXXFLAGS and the OpenMP_CXX_FLAGS added above 22 | 23 | if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") 24 | SET( CMAKE_CXX_FLAGS "-Ofast -std=c++11 -DHAVE_CXX0X -openmp -fpic -ftree-vectorize" ) # NOTE(review): this SET (like the GNU/MSVC ones below) replaces every earlier CXX flag; "-openmp" is not clang's OpenMP spelling (-fopenmp) - OpenMP is supplied per target via OpenMP::OpenMP_CXX in benchmark/CMakeLists.txt, confirm before removing it 25 |
check_cxx_compiler_flag("-march=native" COMPILER_SUPPORT_NATIVE_FLAG) 26 | if(COMPILER_SUPPORT_NATIVE_FLAG) 27 | SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native" ) 28 | message("set -march=native flag") 29 | else() 30 | check_cxx_compiler_flag("-mcpu=apple-m1" COMPILER_SUPPORT_M1_FLAG) 31 | if(COMPILER_SUPPORT_M1_FLAG) 32 | SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mcpu=apple-m1" ) 33 | message("set -mcpu=apple-m1 flag") 34 | endif() 35 | endif() 36 | elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") 37 | SET( CMAKE_CXX_FLAGS "-Ofast -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" ) # NOTE(review): -w suppresses all warnings (likely silencing the per-target -Wall in benchmark/CMakeLists.txt); -lrt is a linker flag placed in compile flags 38 | elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") 39 | SET( CMAKE_CXX_FLAGS "/O2 -DHAVE_CXX0X /W1 /openmp /EHsc" ) 40 | endif() 41 | 42 | 43 | option(NOSIMD "Disable manual SIMD vectorization in hnsw (defines NO_MANUAL_VECTORIZATION)" OFF) 44 | if(NOSIMD) 45 | add_definitions(-DNO_MANUAL_VECTORIZATION) 46 | endif(NOSIMD) 47 | 48 | option(NOPARALLEL "Disable parallel index construction (defines NO_PARALLEL_BUILD)" OFF) 49 | if(NOPARALLEL) 50 | add_definitions(-DNO_PARALLEL_BUILD) 51 | endif(NOPARALLEL) 52 | 53 | include_directories(${PROJECT_SOURCE_DIR}/include) 54 | include_directories(${PROJECT_SOURCE_DIR}/include/common) 55 | include_directories(${PROJECT_SOURCE_DIR}/src) 56 | 57 | 58 | 59 | add_subdirectory(src/common) 60 | add_subdirectory(benchmark) 61 | 62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 rutgers-db 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SeRF 2 | 3 | This repo is the implementation of [SeRF: Segment Graph for Range Filtering Approximate Nearest Neighbor Search](https://dl.acm.org/doi/10.1145/3639324). 4 | 5 | | file | description | 6 | |:--:|:--:| 7 | | segment_graph_1d.h | SegmentGraph for halfbounded query | 8 | | segment_graph_2d.h | 2D-SegmentGraph for arbitrary query | 9 | 10 | > [!NOTE] 11 | > [09/09/2024] We've optimized compile options and added SIMD support. By default, construction and search operations now run in parallel with SIMD, resulting in an approximate 1.5x boost in query speed. 12 | 13 | ## Quick Start 14 | 15 | ### Compile and Run 16 | 17 | SeRF can be built from source using CMake. The only dependencies required are C++11 and OpenMP. 18 | 19 | We have tested the following build instructions on macOS (ARM) and Linux (x86-64): 20 | 21 | ```bash 22 | mkdir build && cd build 23 | cmake .. 
24 | make 25 | ``` 26 | 27 | Running the example benchmarks on the DEEP dataset: 28 | 29 | ```bash 30 | ./benchmark/serf_halfbound -N 10000 -dataset_path [path_to_deep_base.fvecs] -query_path [path_to_deep_query.fvecs] 31 | ./benchmark/serf_arbitrary -N 10000 -dataset_path [path_to_deep_base.fvecs] -query_path [path_to_deep_query.fvecs] 32 | 33 | # Running on sample dataset under "sample_data" 34 | ./benchmark/serf_halfbound -N 10000 -dataset_path ../sample_data/deep_10k.fvecs -query_path ../sample_data/deep_query.fvecs 35 | ./benchmark/serf_arbitrary -N 10000 -dataset_path ../sample_data/deep_10k.fvecs -query_path ../sample_data/deep_query.fvecs 36 | ``` 37 | 38 | Parameters: 39 | 40 | - `-dataset_path`: The base dataset path for indexing, pre-sorted by search key 41 | - `-query_path`: The path to the query vectors 42 | - `-N`: The number of base vectors (the top N) to load for indexing; the benchmark programs default to 100000 if not specified 43 | 44 | Optional Parameters: 45 | 46 | - `-dataset`: The dataset name, supported values: `deep,yt8m-audio,wiki-image` 47 | - `-index_k`: The maximum number of neighbors of the HNSW index. 48 | - `-ef_con`: ef_construction for the HNSW index. 49 | - `-ef_max`: The maximum number of neighbors for building SeRF (accepted by `serf_arbitrary` only). 50 | 51 | `index_k`, `ef_con`, `ef_max`, and `ef_search` (the ef values used at search time) accept multiple values separated by `,`. 52 | 53 | Sample command: 54 | 55 | ```bash 56 | ./benchmark/serf_arbitrary -N 10000 -dataset_path ../sample_data/deep_10k.fvecs -query_path ../sample_data/deep_query.fvecs -ef_search 40,60,80,100 -index_k 16 57 | ``` 58 | 59 | Some parameters are hardcoded for ease of demonstration; you can change them in the code and recompile.
60 | 61 | ## Dataset 62 | 63 | | Dataset | Data type | Dimensions | Search Key | 64 | | :- | :-: | :-: | :-: | 65 | | [DEEP](http://sites.skoltech.ru/compvision/noimi/) | float | 96 | Synthetic | 66 | | [Youtube-Audio](https://research.google.com/youtube8m/download.html) | float | 128 | Video Release Time | 67 | | [WIT-Image](https://www.kaggle.com/c/wikipedia-image-caption/overview) | float | 1024 | Image Size | 68 | 69 | 74 | 75 | ## Reference 76 | 77 | ```text 78 | @article{SeRF, 79 | author = {Zuo, Chaoji and Qiao, Miao and Zhou, Wenchao and Li, Feifei and Deng, Dong}, 80 | title = {SeRF: Segment Graph for Range-Filtering Approximate Nearest Neighbor Search}, 81 | year = {2024}, 82 | issue_date = {February 2024}, 83 | publisher = {Association for Computing Machinery}, 84 | address = {New York, NY, USA}, 85 | volume = {2}, 86 | number = {1}, 87 | url = {https://doi.org/10.1145/3639324}, 88 | doi = {10.1145/3639324} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /benchmark/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(CMAKE_CXX_STANDARD 17) 2 | 3 | add_executable(generate_groundtruth generate_groundtruth.cc) 4 | target_link_libraries(generate_groundtruth UTIL) 5 | target_compile_options(generate_groundtruth PRIVATE -Wall ${OpenMP_CXX_FLAGS}) 6 | target_link_libraries(generate_groundtruth ${OpenMP_CXX_FLAGS}) 7 | target_link_libraries(generate_groundtruth OpenMP::OpenMP_CXX) 8 | 9 | 10 | add_executable(serf_halfbound serf_halfbound.cc) 11 | target_link_libraries(serf_halfbound UTIL) 12 | target_compile_options(serf_halfbound PRIVATE -Wall ${OpenMP_CXX_FLAGS}) 13 | target_link_libraries(serf_halfbound ${OpenMP_CXX_FLAGS}) 14 | target_link_libraries(serf_halfbound OpenMP::OpenMP_CXX) 15 | 16 | add_executable(serf_arbitrary serf_arbitrary.cc) 17 | target_link_libraries(serf_arbitrary UTIL) 18 | target_compile_options(serf_arbitrary PRIVATE -Wall
${OpenMP_CXX_FLAGS}) 19 | target_link_libraries(serf_arbitrary ${OpenMP_CXX_FLAGS}) 20 | target_link_libraries(serf_arbitrary OpenMP::OpenMP_CXX) -------------------------------------------------------------------------------- /benchmark/generate_groundtruth.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * @file generate_groundtruth.cc 3 | * @author Chaoji Zuo (chaoji.zuo@rutgers.edu) 4 | * @brief Generate half-bounded and arbitrary range-filtering benchmark queries and their groundtruth files 5 | * @date 2023-12-22 6 | * 7 | * @copyright Copyright (c) 2023 8 | */ 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "data_processing.h" 18 | #include "data_wrapper.h" 19 | #include "logger.h" 20 | #include "index_base.h" 21 | #include "reader.h" 22 | #include "utils.h" 23 | 24 | #ifdef __linux__ 25 | #include "sys/sysinfo.h" 26 | #include "sys/types.h" 27 | #endif 28 | 29 | using std::cout; 30 | using std::endl; 31 | using std::string; 32 | using std::to_string; 33 | using std::vector; 34 | 35 | int main(int argc, char **argv) { 36 | // Parameters 37 | string dataset = "deep"; 38 | int data_size = 100000; 39 | string dataset_path = ""; 40 | string query_path = ""; 41 | string groundtruth_prefix = ""; 42 | int query_num = 1000; 43 | int query_k = 10; 44 | 45 | for (int i = 0; i + 1 < argc; i++) { // every flag takes a value, so never read past argv[argc - 1] 46 | string arg = argv[i]; 47 | // if (arg == "-dataset") dataset = string(argv[i + 1]); 48 | if (arg == "-N") data_size = atoi(argv[i + 1]); 49 | if (arg == "-dataset_path") dataset_path = string(argv[i + 1]); 50 | if (arg == "-query_path") query_path = string(argv[i + 1]); 51 | if (arg == "-groundtruth_prefix") groundtruth_prefix = string(argv[i + 1]); 52 | } 53 | 54 | string size_symbol = to_string(data_size); // fallback name segment for sizes without a short alias 55 | if (data_size == 100000) { 56 | size_symbol = "100k"; 57 | } else if (data_size == 1000000) { 58 | size_symbol = "1m"; 59 | } 60 | 61 | DataWrapper data_wrapper(query_num, query_k, dataset, data_size); 62 | data_wrapper.readData(dataset_path,
query_path); 63 | 64 | // data_wrapper.generateHalfBoundedQueriesAndGroundtruth( 65 | // true, groundtruth_prefix + "benchmark-groundtruth-deep-" + size_symbol 66 | // + 67 | // "-num1000-k10.halfbounded.cvs"); 68 | // data_wrapper.generateRangeFilteringQueriesAndGroundtruth( 69 | // true, groundtruth_prefix + "benchmark-groundtruth-deep-" + size_symbol 70 | // + 71 | // "-num1000-k10.arbitrary.cvs"); 72 | 73 | data_wrapper.generateHalfBoundedQueriesAndGroundtruthBenchmark( 74 | true, groundtruth_prefix + "benchmark-groundtruth-" + dataset + "-" + size_symbol + 75 | "-num" + to_string(query_num) + "-k" + to_string(query_k) + ".halfbounded.cvs"); 76 | data_wrapper.generateRangeFilteringQueriesAndGroundtruthBenchmark( 77 | true, groundtruth_prefix + "benchmark-groundtruth-" + dataset + "-" + size_symbol + 78 | "-num" + to_string(query_num) + "-k" + to_string(query_k) + ".arbitrary.cvs"); 79 | return 0; 80 | } -------------------------------------------------------------------------------- /benchmark/serf_arbitrary.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * @file serf_arbitrary.cc 3 | * @author Chaoji Zuo (chaoji.zuo@rutgers.edu) 4 | * @brief Benchmark Arbitrary Range Filter Search 5 | * @date 2024-11-17 6 | * 7 | * @copyright Copyright (c) 2024 8 | */ 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | // #include "baselines/knn_first_hnsw.h" 19 | #include "data_processing.h" 20 | #include "data_wrapper.h" 21 | #include "index_base.h" 22 | #include "logger.h" 23 | #include "reader.h" 24 | #include "segment_graph_2d.h" 25 | #include "utils.h" 26 | 27 | #ifdef __linux__ 28 | #include "sys/sysinfo.h" 29 | #include "sys/types.h" 30 | #endif 31 | 32 | using std::cout; 33 | using std::endl; 34 | using std::string; 35 | using std::to_string; 36 | using std::vector; 37 | 38 | void log_result_recorder( 39 | const std::map> &result_recorder, 40 | const std::map &comparison_recorder, const int amount) { 41 | for (auto it : result_recorder) { 42 | cout <<
std::setiosflags(ios::fixed) << std::setprecision(4) 43 | << "range: " << it.first 44 | << "\t recall: " << it.second.first / (amount / result_recorder.size()) 45 | << "\t QPS: " << std::setprecision(0) 46 | << (amount / result_recorder.size()) / it.second.second << "\t Comps: " 47 | << comparison_recorder.at(it.first) / (amount / result_recorder.size()) 48 | << endl; 49 | } 50 | } 51 | 52 | int main(int argc, char **argv) { 53 | #ifdef USE_SSE 54 | cout << "Use SSE" << endl; 55 | #endif 56 | #ifdef USE_AVX 57 | cout << "Use AVX" << endl; 58 | #endif 59 | #ifdef USE_AVX512 60 | cout << "Use AVX512" << endl; 61 | #endif 62 | #ifndef NO_PARALLEL_BUILD 63 | cout << "Index Construct Parallelly" << endl; 64 | #endif 65 | 66 | // Parameters 67 | string dataset = "deep"; 68 | int data_size = 100000; 69 | string dataset_path = ""; 70 | string method = ""; 71 | string query_path = ""; 72 | string groundtruth_path = ""; 73 | vector index_k_list = {8}; 74 | vector ef_construction_list = {100}; 75 | int query_num = 1000; 76 | int query_k = 10; 77 | vector ef_max_list = {500}; 78 | vector searchef_para_range_list = {16, 64, 256}; 79 | bool full_range = false; 80 | 81 | string indexk_str = "8"; 82 | string ef_con_str = "100"; 83 | string ef_max_str = "500"; 84 | string ef_search_str = "16,64,256"; 85 | string version = "Benchmark"; 86 | 87 | for (int i = 0; i < argc; i++) { 88 | string arg = argv[i]; if (i + 1 >= argc) { if (arg == "-full_range") full_range = true; break; } // last arg: only the value-less flag is valid; the flags below read argv[i + 1] 89 | if (arg == "-dataset") dataset = string(argv[i + 1]); 90 | if (arg == "-N") data_size = atoi(argv[i + 1]); 91 | if (arg == "-dataset_path") dataset_path = string(argv[i + 1]); 92 | if (arg == "-query_path") query_path = string(argv[i + 1]); 93 | if (arg == "-groundtruth_path") groundtruth_path = string(argv[i + 1]); 94 | if (arg == "-index_k") indexk_str = string(argv[i + 1]); 95 | if (arg == "-ef_con") ef_con_str = string(argv[i + 1]); 96 | if (arg == "-ef_max") ef_max_str = string(argv[i + 1]); 97 | if (arg == "-ef_search") ef_search_str = string(argv[i + 1]); 98 | if (arg ==
"-method") method = string(argv[i + 1]); 99 | if (arg == "-full_range") full_range = true; 100 | } 101 | 102 | index_k_list = str2vec(indexk_str); 103 | ef_construction_list = str2vec(ef_con_str); 104 | ef_max_list = str2vec(ef_max_str); 105 | searchef_para_range_list = str2vec(ef_search_str); 106 | 107 | assert(index_k_list.size() != 0); 108 | assert(ef_construction_list.size() != 0); 109 | // assert(groundtruth_path != ""); 110 | 111 | DataWrapper data_wrapper(query_num, query_k, dataset, data_size); 112 | data_wrapper.readData(dataset_path, query_path); 113 | 114 | // Generate groundtruth 115 | if (full_range) 116 | data_wrapper.generateRangeFilteringQueriesAndGroundtruth(false); 117 | else 118 | data_wrapper.generateRangeFilteringQueriesAndGroundtruthBenchmark(false); 119 | // Or you can load groundtruth from the given path 120 | // data_wrapper.LoadGroundtruth(groundtruth_path); 121 | 122 | assert(data_wrapper.query_ids.size() == data_wrapper.query_ranges.size()); 123 | 124 | cout << "index K:" << endl; 125 | print_set(index_k_list); 126 | cout << "ef construction:" << endl; 127 | print_set(ef_construction_list); 128 | cout << "search ef:" << endl; 129 | print_set(searchef_para_range_list); 130 | 131 | data_wrapper.version = version; 132 | 133 | base_hnsw::L2Space ss(data_wrapper.data_dim); 134 | 135 | timeval t1, t2; 136 | 137 | for (unsigned index_k : index_k_list) { 138 | for (unsigned ef_max : ef_max_list) { 139 | for (unsigned ef_construction : ef_construction_list) { 140 | BaseIndex::IndexParams i_params(index_k, ef_construction, 141 | ef_construction, ef_max); 142 | { 143 | cout << endl; 144 | i_params.recursion_type = BaseIndex::IndexParams::MAX_POS; 145 | SeRF::IndexSegmentGraph2D index(&ss, &data_wrapper); 146 | // rangeindex::RecursionIndex index(&ss, &data_wrapper); 147 | BaseIndex::SearchInfo search_info(&data_wrapper, &i_params, "SeRF_2D", 148 | "benchmark"); 149 | 150 | cout << "Method: " << search_info.method << endl; 151 | cout << "parameters:
ef_construction ( " + 152 | to_string(i_params.ef_construction) + " ) index-k( " 153 | << i_params.K << ") ef_max (" << i_params.ef_max << ") " 154 | << endl; 155 | gettimeofday(&t1, NULL); 156 | index.buildIndex(&i_params); 157 | gettimeofday(&t2, NULL); 158 | logTime(t1, t2, "Build Index Time"); 159 | cout << "Total # of Neighbors: " << index.index_info->nodes_amount 160 | << endl; 161 | 162 | { 163 | timeval tt3, tt4; 164 | BaseIndex::SearchParams s_params; 165 | s_params.query_K = data_wrapper.query_k; 166 | for (auto one_searchef : searchef_para_range_list) { 167 | s_params.search_ef = one_searchef; 168 | std::map> 169 | result_recorder; // first->precision, second->query_time 170 | std::map comparison_recorder; 171 | gettimeofday(&tt3, NULL); 172 | for (size_t idx = 0; idx < data_wrapper.query_ids.size(); idx++) { 173 | int one_id = data_wrapper.query_ids.at(idx); 174 | s_params.query_range = 175 | data_wrapper.query_ranges.at(idx).second - 176 | data_wrapper.query_ranges.at(idx).first + 1; 177 | auto res = index.rangeFilteringSearchOutBound( 178 | &s_params, &search_info, data_wrapper.querys.at(one_id), 179 | data_wrapper.query_ranges.at(idx)); 180 | search_info.precision = 181 | countPrecision(data_wrapper.groundtruth.at(idx), res); 182 | result_recorder[s_params.query_range].first += 183 | search_info.precision; 184 | result_recorder[s_params.query_range].second += 185 | search_info.internal_search_time; 186 | comparison_recorder[s_params.query_range] += 187 | search_info.total_comparison; 188 | } 189 | gettimeofday(&tt4, NULL); // stop the batch clock; tt4 was previously read uninitialized by logTime below 190 | cout << endl 191 | << "Search ef: " << one_searchef << endl 192 | << "========================" << endl; 193 | log_result_recorder(result_recorder, comparison_recorder, 194 | data_wrapper.query_ids.size()); 195 | cout << "========================" << endl; 196 | logTime(tt3, tt4, "total query time"); 197 | } 198 | } 199 | } 200 | } 201 | } 202 | } 203 | 204 | return 0; 205 | } 206 |
-------------------------------------------------------------------------------- /benchmark/serf_halfbound.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * @file serf_halfbound.cc 3 | * @author Chaoji Zuo (chaoji.zuo@rutgers.edu) 4 | * @brief Benchmark Half-Bounded Range Filter Search 5 | * @date 2023-12-22 6 | * 7 | * @copyright Copyright (c) 2023 8 | */ 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | // #include "baselines/knn_first_hnsw.h" 20 | #include "data_processing.h" 21 | #include "data_wrapper.h" 22 | #include "index_base.h" 23 | #include "logger.h" 24 | #include "reader.h" 25 | #include "segment_graph_1d.h" 26 | #include "utils.h" 27 | 28 | #ifdef __linux__ 29 | #include "sys/sysinfo.h" 30 | #include "sys/types.h" 31 | #endif 32 | 33 | using std::cout; 34 | using std::endl; 35 | using std::string; 36 | using std::to_string; 37 | using std::vector; 38 | 39 | void log_result_recorder( 40 | const std::map> &result_recorder, 41 | const std::map &comparison_recorder, const int amount) { 42 | for (auto it : result_recorder) { 43 | cout << std::setiosflags(ios::fixed) << std::setprecision(4) 44 | << "range: " << it.first 45 | << "\t recall: " << it.second.first / (amount / result_recorder.size()) 46 | << "\t QPS: " << std::setprecision(0) 47 | << (amount / result_recorder.size()) / it.second.second << "\t Comps: " 48 | << comparison_recorder.at(it.first) / (amount / result_recorder.size()) 49 | << endl; 50 | } 51 | } 52 | 53 | int main(int argc, char **argv) { 54 | #ifdef USE_SSE 55 | cout << "Use SSE" << endl; 56 | #endif 57 | #ifdef USE_AVX 58 | cout << "Use AVX" << endl; 59 | #endif 60 | #ifdef USE_AVX512 61 | cout << "Use AVX512" << endl; 62 | #endif 63 | #ifndef NO_PARALLEL_BUILD 64 | cout << "Index Construct Parallelly" << endl; 65 | #endif 66 | 67 | // Parameters 68 | string dataset = "deep"; 69 | int data_size = 100000; 70 |
string dataset_path = ""; 71 | string method = ""; 72 | string query_path = ""; 73 | string groundtruth_path = ""; 74 | vector index_k_list = {8}; 75 | vector ef_construction_list = {100}; // always overwritten by str2vec(ef_con_str) below; kept in sync with ef_con_str 76 | int query_num = 1000; 77 | int query_k = 10; 78 | vector searchef_para_range_list = {16, 64, 256}; 79 | bool full_range = false; 80 | 81 | string indexk_str = "8"; 82 | string ef_con_str = "100"; 83 | string ef_max_str = "500"; // unused: -ef_max is not parsed here and the 1D index sets i_params.ef_max = 0 84 | string ef_search_str = "16,64,256"; 85 | string version = "Benchmark"; 86 | 87 | for (int i = 0; i < argc; i++) { 88 | string arg = argv[i]; if (i + 1 >= argc) { if (arg == "-full_range") full_range = true; break; } // last arg: only the value-less flag is valid; the flags below read argv[i + 1] 89 | if (arg == "-dataset") dataset = string(argv[i + 1]); 90 | if (arg == "-N") data_size = atoi(argv[i + 1]); 91 | if (arg == "-dataset_path") dataset_path = string(argv[i + 1]); 92 | if (arg == "-query_path") query_path = string(argv[i + 1]); 93 | if (arg == "-groundtruth_path") groundtruth_path = string(argv[i + 1]); 94 | if (arg == "-index_k") indexk_str = string(argv[i + 1]); 95 | if (arg == "-ef_con") ef_con_str = string(argv[i + 1]); 96 | if (arg == "-ef_search") ef_search_str = string(argv[i + 1]); 97 | if (arg == "-method") method = string(argv[i + 1]); 98 | if (arg == "-full_range") full_range = true; 99 | } 100 | 101 | index_k_list = str2vec(indexk_str); 102 | ef_construction_list = str2vec(ef_con_str); 103 | searchef_para_range_list = str2vec(ef_search_str); 104 | 105 | assert(index_k_list.size() != 0); 106 | assert(ef_construction_list.size() != 0); 107 | // assert(groundtruth_path != ""); 108 | 109 | DataWrapper data_wrapper(query_num, query_k, dataset, data_size); 110 | data_wrapper.readData(dataset_path, query_path); 111 | 112 | // Generate groundtruth 113 | 114 | if (full_range) 115 | data_wrapper.generateHalfBoundedQueriesAndGroundtruth(false); 116 | else 117 | data_wrapper.generateHalfBoundedQueriesAndGroundtruthBenchmark(false); 118 | 119 | // Or you can load groundtruth from the given path 120 | // data_wrapper.LoadGroundtruth(groundtruth_path); 121 | 122 |
assert(data_wrapper.query_ids.size() == data_wrapper.query_ranges.size()); 123 | 124 | 125 | cout << "index K:" << endl; 126 | print_set(index_k_list); 127 | cout << "ef construction:" << endl; 128 | print_set(ef_construction_list); 129 | cout << "search ef:" << endl; 130 | print_set(searchef_para_range_list); 131 | 132 | data_wrapper.version = version; 133 | 134 | base_hnsw::L2Space ss(data_wrapper.data_dim); 135 | 136 | timeval t1, t2; 137 | 138 | for (unsigned index_k : index_k_list) { 139 | for (unsigned ef_construction : ef_construction_list) { 140 | BaseIndex::IndexParams i_params; 141 | i_params.ef_large_for_pruning = 0; 142 | i_params.ef_max = 0; 143 | i_params.ef_construction = ef_construction; 144 | i_params.K = index_k; 145 | { 146 | cout << endl; 147 | i_params.recursion_type = BaseIndex::IndexParams::MAX_POS; 148 | SeRF::IndexSegmentGraph1D index(&ss, &data_wrapper); 149 | BaseIndex::SearchInfo search_info(&data_wrapper, &i_params, "SeRF", 150 | "benchmark"); 151 | cout << "Method: " << search_info.method << endl; 152 | cout << "parameters: ef_construction ( " + 153 | to_string(i_params.ef_construction) + " ) index-k( " 154 | << i_params.K << ")" << endl; 155 | gettimeofday(&t1, NULL); 156 | index.buildIndex(&i_params); 157 | gettimeofday(&t2, NULL); 158 | logTime(t1, t2, "Build Index Time"); 159 | cout << "Total # of Neighbors: " << index.index_info->nodes_amount 160 | << endl; 161 | 162 | { 163 | timeval tt3, tt4; 164 | BaseIndex::SearchParams s_params; 165 | s_params.query_K = data_wrapper.query_k; 166 | for (auto one_searchef : searchef_para_range_list) { 167 | s_params.search_ef = one_searchef; 168 | std::map> 169 | result_recorder; // first->precision, second->query_time 170 | std::map comparison_recorder; 171 | gettimeofday(&tt3, NULL); 172 | for (size_t idx = 0; idx < data_wrapper.query_ids.size(); idx++) { 173 | int one_id = data_wrapper.query_ids.at(idx); 174 | s_params.query_range = data_wrapper.query_ranges.at(idx).second - 175 |
data_wrapper.query_ranges.at(idx).first + 176 | 1; 177 | 178 | auto res = index.rangeFilteringSearchInRange( 179 | &s_params, &search_info, data_wrapper.querys.at(one_id), 180 | data_wrapper.query_ranges.at(idx)); 181 | search_info.precision = 182 | countPrecision(data_wrapper.groundtruth.at(idx), res); 183 | result_recorder[s_params.query_range].first += 184 | search_info.precision; 185 | result_recorder[s_params.query_range].second += 186 | search_info.internal_search_time; 187 | comparison_recorder[s_params.query_range] += 188 | search_info.total_comparison; 189 | } 190 | gettimeofday(&tt4, NULL); cout << endl 191 | << "Search ef: " << one_searchef << endl 192 | << "========================" << endl; 193 | log_result_recorder(result_recorder, comparison_recorder, 194 | data_wrapper.query_ids.size()); 195 | cout << "========================" << endl; 196 | logTime(tt3, tt4, "total query time"); // tt4 is set right after the query loop above 197 | } 198 | } 199 | } 200 | } 201 | } 202 | 203 | // Uses 7 points in the benchmark: 0.1% 0.5% 1% 5% 10% 50% 100% 204 | // Not storing the meta results, just output the statistics 205 | 206 | return 0; 207 | } -------------------------------------------------------------------------------- /include/base_index.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Base Class of range filtering vector search index 3 | * 4 | * Author: Chaoji Zuo 5 | * Email: chaoji.zuo@rutgers.edu 6 | * Date: June 19, 2023 7 | */ 8 | #pragma once 9 | 10 | #include 11 | #include 12 | 13 | #include "data_vecs.h" 14 | 15 | namespace baseindex { 16 | 17 | static unsigned const default_K = 16; 18 | static unsigned const default_ef_construction = 400; 19 | 20 | class BaseIndex { 21 | public: 22 | BaseIndex(hnswlib_incre::SpaceInterface* s) { 23 | fstdistfunc_ = s->get_dist_func(); 24 | dist_func_param_ = s->get_dist_func_param(); 25 | } 26 | // virtual void buildIndex() = 0; 27 | virtual ~BaseIndex() {} 28 | 29 | hnswlib_incre::DISTFUNC fstdistfunc_; 30 | void* dist_func_param_; 31 |
double sort_time = 0; 32 | double nn_build_time = 0; 33 | int num_search_comparison = 0; 34 | 35 | // Indexing parameters 36 | struct IndexParams { 37 | // original params in hnsw 38 | unsigned K; 39 | unsigned ef_construction; 40 | unsigned random_seed; 41 | 42 | IndexParams() 43 | : K(default_K), 44 | ef_construction(default_ef_construction), 45 | random_seed(2023) {} 46 | }; 47 | 48 | struct Searchparams { 49 | // vector weights; 50 | // vector> weights; 51 | 52 | unsigned internal_search_K; 53 | }; 54 | 55 | struct IndexInfo {}; 56 | 57 | struct SearchInfo { 58 | SearchInfo(const DataWrapper* data, const IndexParams* index_params, 59 | const Searchparams* search_params, const string& meth, 60 | const string& ver) { 61 | data_wrapper = data; 62 | index = index_params; 63 | search = search_params; 64 | version = ver; 65 | method = meth; 66 | // save_path = "../exp/search/" + version + "-" + method + "-" + 67 | // data_wrapper->dataset + "-" + 68 | // std::to_string(data_wrapper->data_size) + ".csv"; 69 | Path(ver); 70 | }; 71 | const DataWrapper* data_wrapper; 72 | const IndexParams* index; 73 | const Searchparams* search; 74 | 75 | string method; 76 | string save_path; 77 | string version; 78 | 79 | double time = 0; 80 | double precision = 0; 81 | int query_id = 0; 82 | unsigned break_counter = 0; // written out by SaveCsv, so never leave it uninitialized 83 | double internal_search_time = 0; 84 | size_t visited_num = 0; 85 | double fetch_nns_time = 0; 86 | double cal_dist_time = 0; 87 | double other_process_time = 0; 88 | size_t total_comparison = 0; 89 | // (a second `size_t visited_num;` declaration was removed here: redeclaring a member is ill-formed) 90 | size_t path_counter = 0; 91 | 92 | bool is_investigate = false; 93 | 94 | void Path(const string& ver) { 95 | version = ver; 96 | save_path = "../exp/search/" + version + "-" + method + "-" + 97 | data_wrapper->dataset + "-" + 98 | std::to_string(data_wrapper->data_size) + ".csv"; 99 | std::cout << "Save result to :" << save_path << std::endl; 100 | }; 101 | 102 | void SaveCsv() { 103 | std::ofstream file; 104 | file.open(save_path, std::ios_base::app); 105 | if (file) { 106 | file << 107 | //
version << "," << method << "," << 108 | time << "," << precision << "," << search->internal_search_K << "," 109 | << internal_search_time << "," << break_counter << "," 110 | << visited_num << "," << data_wrapper->data_size << "," << index->K 111 | << "," << index->ef_construction; 112 | file << "\n"; 113 | } 114 | file.close(); 115 | }; 116 | 117 | void SavePathInvestigate(const float v1, const float v2, const float v3, 118 | const float v4, bool is_new_row = false) { 119 | string investigate_path = "../exp/search/" + version + "-" + method + 120 | "-" + data_wrapper->dataset + "-" + 121 | std::to_string(data_wrapper->data_size) + 122 | "-invetigate-path.csv"; 123 | std::ofstream file; 124 | file.open(investigate_path, std::ios_base::app); 125 | 126 | if (is_new_row) { 127 | file << "\n"; 128 | } else if (file) { 129 | file << v1 << "," << v2 << "," << v3 << "," << v4; 130 | file << "\n"; 131 | } 132 | file.close(); 133 | }; 134 | }; 135 | }; 136 | 137 | } // namespace baseindex 138 | -------------------------------------------------------------------------------- /include/common/data_processing.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file data_processing.h 3 | * @author Chaoji Zuo (chaoji.zuo@rutgers.edu) 4 | * @brief Functions for processing data, generating queries and groundtruth 5 | * @date 2023-06-19 6 | * 7 | * @copyright Copyright (c) 2023 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "data_wrapper.h" 13 | #include "utils.h" 14 | 15 | // void SynthesizeQuerys(const vector> &nodes, 16 | // vector> &querys, const int query_num); 17 | 18 | // void SynthesizeQuerys(DataWrapper &data_wrapper, const int query_num); 19 | 20 | // vector greedyNearest(const vector> &nodes1, 21 | // const vector> &nodes2, 22 | // const std::pair, vector> &query, 23 | // const pair &weights, 24 | // const int k_smallest); 25 | 26 | // void calculateGroundtruth(DataWrapper &data_wrapper); 27 | 28 | // void
calculateGroundtruthHalfBounded(DataWrapper &runner, bool is_save = false); 29 | -------------------------------------------------------------------------------- /include/common/data_wrapper.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file data_vecs.h 3 | * @author Chaoji Zuo (chaoji.zuo@rutgers.edu) 4 | * @brief Control the raw vector and querys 5 | * @date 2023-06-19 6 | * 7 | * @copyright Copyright (c) 2023 8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | using std::pair; 16 | using std::string; 17 | using std::vector; 18 | 19 | class DataWrapper { 20 | public: 21 | DataWrapper(int num, int k_, string dataset_name, int data_size_) 22 | : dataset(dataset_name), 23 | data_size(data_size_), 24 | query_num(num), 25 | query_k(k_){}; 26 | const string dataset; 27 | string version; 28 | const int data_size; 29 | const int query_num; 30 | 31 | const int query_k; 32 | size_t data_dim; 33 | 34 | bool is_even_weight; 35 | bool real_keys; 36 | 37 | // TODO: change vector storage to array 38 | vector> nodes; 39 | vector nodes_keys; // search_keys 40 | vector> 41 | querys; // raw querys; less than query_ids and query_ranges; 42 | vector querys_keys; 43 | vector> query_ranges; 44 | vector> groundtruth; 45 | vector query_ids; 46 | void readData(string &dataset_path, string &query_path); 47 | void generateRangeFilteringQueriesAndGroundtruth(bool is_save = false, 48 | const string path = ""); 49 | void generateHalfBoundedQueriesAndGroundtruth(bool is_save = false, 50 | const string path = ""); 51 | void LoadGroundtruth(const string >_path); 52 | 53 | void generateRangeFilteringQueriesAndGroundtruthScalability( 54 | bool is_save = false, const string path = ""); 55 | 56 | void generateHalfBoundedQueriesAndGroundtruthScalability( 57 | bool is_save = false, const string path = ""); 58 | 59 | void generateHalfBoundedQueriesAndGroundtruthBenchmark( 60 | bool is_save_to_file, const string save_path = ""); 61 
| 62 | void generateRangeFilteringQueriesAndGroundtruthBenchmark( 63 | bool is_save_to_file, const string save_path = ""); 64 | }; -------------------------------------------------------------------------------- /include/common/logger.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file logger.h 3 | * @author Chaoji Zuo (chaoji.zuo@rutgers.edu) 4 | * @brief for output exp result to csv files 5 | * @date 2023-06-19 6 | * 7 | * @copyright Copyright (c) 2023 8 | */ 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | using std::cout; 18 | using std::endl; 19 | using std::string; 20 | using std::vector; 21 | 22 | // compact hnsw 23 | 24 | void SaveToCSVRow(const string &path, const int idx, const int l_bound, 25 | const int r_bound, const int range, const int K_neighbor, 26 | const int initial_graph_size, const int index_graph_size, 27 | const string &method, const int search_ef, 28 | const double &precision, const double &appr_ratio, 29 | const double &search_time, const int data_size, 30 | const int num_comparison, const double path_time); 31 | 32 | // knn-first 33 | void SaveToCSVRow(const string &path, const int idx, const int l_bound, 34 | const int r_bound, const int range, const int K_neighbor, 35 | const int initial_graph_size, const int index_graph_size, 36 | const string &method, const int search_ef, 37 | const double &precision, const double &appr_ratio, 38 | const double &search_time, const int data_size, 39 | const size_t num_search_comparison, 40 | const double out_bound_candidates, 41 | const double in_bound_candidates); 42 | 43 | void SaveToIndexCSVRow(const string &path, const string &version, 44 | const string &method, const int data_size, 45 | const int initial_graph_size, const int index_graph_size, 46 | const double nn_build_time, const double sort_time, 47 | const double build_time, const double memory, 48 | const int node_amount, const int 
window_count, 49 | const double index_size); 50 | 51 | // For range filtering, HNSW detail 52 | void SaveToCSVRow(const string &path, const int idx, const int l_bound, 53 | const int r_bound, const int range, const int K_neighbor, 54 | const int initial_graph_size, const int index_graph_size, 55 | const string &method, const int search_ef, 56 | const double &precision, const double &appr_ratio, 57 | const double &search_time, const int data_size, 58 | vector &res, vector &dists); 59 | 60 | // For PQ 61 | void SaveToCSVRow(const string &path, const int idx, const int l_bound, 62 | const int r_bound, const int range, const int K_neighbor, 63 | const int M_pq, const int Ks_pq, const string &method, 64 | const double &precision, const double &appr_ratio, 65 | const double &search_time, const int data_size); -------------------------------------------------------------------------------- /include/common/reader.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file reader.h 3 | * @author Chaoji Zuo (chaoji.zuo@rutgers.edu) 4 | * @brief Read Vector data 5 | * @date 2023-04-21 6 | * 7 | * @copyright Copyright (c) 2023 8 | * 9 | */ 10 | #pragma once 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | using std::cout; 22 | using std::endl; 23 | using std::getline; 24 | using std::ifstream; 25 | using std::ios; 26 | using std::string; 27 | using std::vector; 28 | 29 | // Interface (abstract basic class) of iterative reader 30 | class I_ItrReader { 31 | public: 32 | virtual ~I_ItrReader() {} 33 | virtual bool IsEnd() = 0; 34 | virtual std::vector Next() = 0; 35 | }; 36 | 37 | // Iterative reader for fvec file 38 | class FvecsItrReader : I_ItrReader { 39 | public: 40 | FvecsItrReader(std::string filename); 41 | bool IsEnd(); 42 | std::vector Next(); 43 | 44 | private: 45 | FvecsItrReader(); // prohibit default construct 46 | std::ifstream ifs; 47 | 
std::vector vec; // store the next vec 48 | bool eof_flag; 49 | }; 50 | 51 | // Iterative reader for bvec file 52 | class BvecsItrReader : I_ItrReader { 53 | public: 54 | BvecsItrReader(std::string filename); 55 | bool IsEnd(); 56 | std::vector Next(); // Read bvec, but return vec 57 | private: 58 | BvecsItrReader(); // prohibit default construct 59 | std::ifstream ifs; 60 | std::vector vec; // store the next vec 61 | bool eof_flag; 62 | }; 63 | 64 | // Proxy class 65 | class ItrReader { 66 | public: 67 | // ext must be "fvecs" or "bvecs" 68 | ItrReader(std::string filename, std::string ext); 69 | ~ItrReader(); 70 | 71 | bool IsEnd(); 72 | std::vector Next(); 73 | 74 | private: 75 | ItrReader(); 76 | I_ItrReader *m_reader; 77 | }; 78 | 79 | // Wrapper. Read top-N vectors 80 | // If top_n = -1, then read all vectors 81 | std::vector> ReadTopN(std::string filename, std::string ext, 82 | int top_n = -1); 83 | 84 | void ReadMatFromTxt(const std::string &path, 85 | std::vector> &data, 86 | const int length_limit); 87 | 88 | void ReadMatFromTxtTwitter(const std::string &path, 89 | std::vector> &data, 90 | const int length_limit); 91 | void ReadMatFromTsv(const std::string &path, 92 | std::vector> &data, 93 | const int length_limit); 94 | 95 | void ReadDataWrapper(vector> &raw_data, vector &search_keys, 96 | const string &dataset, string &dataset_path, 97 | const int item_num); 98 | 99 | void ReadDataWrapper(const string &dataset, string &dataset_path, 100 | vector> &raw_data, const int data_size, 101 | string &query_path, vector> &querys, 102 | const int query_size, vector &search_keys); 103 | 104 | void ReadDataWrapper(const string &dataset, string &dataset_path, 105 | vector> &raw_data, const int data_size); 106 | 107 | int YT8M2Int(const string id); 108 | void ReadMatFromTsvYT8M(const string &path, vector> &data, 109 | vector &search_keys, const int length_limit); 110 | 111 | void ReadGroundtruthQuery(std::vector> >, 112 | std::vector> &query_ranges, 113 | std::vector 
&query_ids, std::string gt_path); -------------------------------------------------------------------------------- /include/common/utils.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file utils.h 3 | * @author Chaoji Zuo (chaoji.zuo@rutgers.edu) 4 | * @brief Utils Functions 5 | * @date 2023-04-21 6 | * 7 | * @copyright Copyright (c) 2023 8 | * 9 | */ 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | #ifdef __linux__ 29 | #include "sys/sysinfo.h" 30 | #include "sys/types.h" 31 | #elif __APPLE__ 32 | #include 33 | #include 34 | #include 35 | #include 36 | #endif 37 | 38 | using std::cout; 39 | using std::endl; 40 | using std::getline; 41 | using std::ifstream; 42 | using std::ios; 43 | using std::make_pair; 44 | using std::pair; 45 | using std::string; 46 | using std::vector; 47 | 48 | float EuclideanDistance(const vector &lhs, const vector &rhs, 49 | const int &startDim, int lensDim); 50 | 51 | float EuclideanDistance(const vector &lhs, const vector &rhs); 52 | 53 | float EuclideanDistanceSquare(const vector &lhs, 54 | const vector &rhs); 55 | 56 | void AccumulateTime(timeval &t2, timeval &t1, double &val_time); 57 | void CountTime(timeval &t1, timeval &t2, double &val_time); 58 | double CountTime(timeval &t1, timeval &t2); 59 | 60 | // the same to sort_indexes 61 | template 62 | std::vector sort_permutation(const std::vector &vec) { 63 | std::vector p(vec.size()); 64 | std::iota(p.begin(), p.end(), 0); 65 | std::sort(p.begin(), p.end(), 66 | [&](std::size_t i, std::size_t j) { return vec[i] < vec[j]; }); 67 | return p; 68 | } 69 | 70 | // apply permutation 71 | template 72 | void apply_permutation_in_place(std::vector &vec, 73 | const std::vector &p) { 74 | std::vector done(vec.size()); 75 | for (std::size_t i = 0; i < vec.size(); 
++i) { 76 | if (done[i]) { 77 | continue; 78 | } 79 | done[i] = true; 80 | std::size_t prev_j = i; 81 | std::size_t j = p[i]; 82 | while (i != j) { 83 | std::swap(vec[prev_j], vec[j]); 84 | done[j] = true; 85 | prev_j = j; 86 | j = p[j]; 87 | } 88 | } 89 | } 90 | 91 | template 92 | vector sort_indexes(const vector &v) { 93 | // initialize original index locations 94 | vector idx(v.size()); 95 | iota(idx.begin(), idx.end(), 0); 96 | 97 | // sort indexes based on comparing values in v 98 | // using std::stable_sort instead of std::sort 99 | // to avoid unnecessary index re-orderings 100 | // when v contains elements of equal values 101 | stable_sort(idx.begin(), idx.end(), 102 | [&v](size_t i1, size_t i2) { return v[i1] < v[i2]; }); 103 | 104 | return idx; 105 | } 106 | 107 | template 108 | vector sort_indexes(const vector &v, const int begin_bias, 109 | const int end_bias) { 110 | // initialize original index locations 111 | vector idx(end_bias - begin_bias); 112 | iota(idx.begin() + begin_bias, idx.begin() + end_bias, 0); 113 | 114 | // sort indexes based on comparing values in v 115 | // using std::stable_sort instead of std::sort 116 | // to avoid unnecessary index re-orderings 117 | // when v contains elements of equal values 118 | stable_sort(idx.begin() + begin_bias, idx.begin() + end_bias, 119 | [&v](size_t i1, size_t i2) { return v[i1] < v[i2]; }); 120 | 121 | return idx; 122 | } 123 | 124 | template 125 | void print_set(const vector &v) { 126 | if (v.size() == 0) { 127 | cout << "ERROR: EMPTY VECTOR!" 
<< endl; 128 | return; 129 | } 130 | cout << "vertex in set: {"; 131 | for (size_t i = 0; i < v.size() - 1; i++) { 132 | cout << v[i] << ", "; 133 | } 134 | cout << v.back() << "}" << endl; 135 | } 136 | 137 | void logTime(timeval &begin, timeval &end, const string &log); 138 | 139 | double countPrecision(const vector &truth, const vector &pred); 140 | double countApproximationRatio(const vector> &raw_data, 141 | const vector &truth, 142 | const vector &pred, 143 | const vector &query); 144 | 145 | void print_memory(); 146 | void record_memory(long long &); 147 | #define _INT_MAX 2147483640 148 | 149 | vector greedyNearest(const vector> &dpts, 150 | const vector query, const int k_smallest); 151 | 152 | // void evaluateKNNG(const vector> >, 153 | // const vector> &knng, const int K, double 154 | // &recall, double &precision); 155 | 156 | void rangeGreedy(const vector> &nodes, const int k_smallest, 157 | const int l_bound, const int r_bound); 158 | 159 | void greedyNearest(const int query_pos, const vector> &dpts, 160 | const int k_smallest, const int l_bound, const int r_bound); 161 | 162 | vector greedyNearest(const vector> &dpts, 163 | const vector query, const int l_bound, 164 | const int r_bound, const int k_smallest); 165 | 166 | void heuristicPrune(const vector> &nodes, 167 | vector> &top_candidates, const size_t M); 168 | 169 | vector str2vec(const string str); -------------------------------------------------------------------------------- /include/incremental_hnsw/bruteforce.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | namespace hnswlib_incre { 8 | template 9 | class BruteforceSearch : public AlgorithmInterface { 10 | public: 11 | BruteforceSearch(SpaceInterface *s) { 12 | 13 | } 14 | BruteforceSearch(SpaceInterface *s, const std::string &location) { 15 | loadIndex(location, s); 16 | } 17 | 18 | BruteforceSearch(SpaceInterface *s, size_t 
maxElements) { 19 | maxelements_ = maxElements; 20 | data_size_ = s->get_data_size(); 21 | fstdistfunc_ = s->get_dist_func(); 22 | dist_func_param_ = s->get_dist_func_param(); 23 | size_per_element_ = data_size_ + sizeof(labeltype); 24 | data_ = (char *) malloc(maxElements * size_per_element_); 25 | if (data_ == nullptr) 26 | std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); 27 | cur_element_count = 0; 28 | } 29 | 30 | ~BruteforceSearch() { 31 | free(data_); 32 | } 33 | 34 | char *data_; 35 | size_t maxelements_; 36 | size_t cur_element_count; 37 | size_t size_per_element_; 38 | 39 | size_t data_size_; 40 | DISTFUNC fstdistfunc_; 41 | void *dist_func_param_; 42 | std::mutex index_lock; 43 | 44 | std::unordered_map dict_external_to_internal; 45 | 46 | void addPoint(const void *datapoint, labeltype label) { 47 | 48 | int idx; 49 | { 50 | std::unique_lock lock(index_lock); 51 | 52 | 53 | 54 | auto search=dict_external_to_internal.find(label); 55 | if (search != dict_external_to_internal.end()) { 56 | idx=search->second; 57 | } 58 | else{ 59 | if (cur_element_count >= maxelements_) { 60 | throw std::runtime_error("The number of elements exceeds the specified limit\n"); 61 | } 62 | idx=cur_element_count; 63 | dict_external_to_internal[label] = idx; 64 | cur_element_count++; 65 | } 66 | } 67 | memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); 68 | memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); 69 | 70 | 71 | 72 | 73 | }; 74 | 75 | void removePoint(labeltype cur_external) { 76 | size_t cur_c=dict_external_to_internal[cur_external]; 77 | 78 | dict_external_to_internal.erase(cur_external); 79 | 80 | labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); 81 | dict_external_to_internal[label]=cur_c; 82 | memcpy(data_ + size_per_element_ * cur_c, 83 | data_ + size_per_element_ * (cur_element_count-1), 84 | data_size_+sizeof(labeltype)); 85 | 
cur_element_count--; 86 | 87 | } 88 | 89 | 90 | std::priority_queue> 91 | searchKnn(const void *query_data, size_t k) const { 92 | std::priority_queue> topResults; 93 | if (cur_element_count == 0) return topResults; 94 | for (int i = 0; i < k; i++) { 95 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); 96 | topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + 97 | data_size_)))); 98 | } 99 | dist_t lastdist = topResults.top().first; 100 | for (int i = k; i < cur_element_count; i++) { 101 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); 102 | if (dist <= lastdist) { 103 | topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + 104 | data_size_)))); 105 | if (topResults.size() > k) 106 | topResults.pop(); 107 | lastdist = topResults.top().first; 108 | } 109 | 110 | } 111 | return topResults; 112 | }; 113 | 114 | void saveIndex(const std::string &location) { 115 | std::ofstream output(location, std::ios::binary); 116 | std::streampos position; 117 | 118 | writeBinaryPOD(output, maxelements_); 119 | writeBinaryPOD(output, size_per_element_); 120 | writeBinaryPOD(output, cur_element_count); 121 | 122 | output.write(data_, maxelements_ * size_per_element_); 123 | 124 | output.close(); 125 | } 126 | 127 | void loadIndex(const std::string &location, SpaceInterface *s) { 128 | 129 | 130 | std::ifstream input(location, std::ios::binary); 131 | std::streampos position; 132 | 133 | readBinaryPOD(input, maxelements_); 134 | readBinaryPOD(input, size_per_element_); 135 | readBinaryPOD(input, cur_element_count); 136 | 137 | data_size_ = s->get_data_size(); 138 | fstdistfunc_ = s->get_dist_func(); 139 | dist_func_param_ = s->get_dist_func_param(); 140 | size_per_element_ = data_size_ + sizeof(labeltype); 141 | data_ = (char *) malloc(maxelements_ * size_per_element_); 142 | if (data_ == nullptr) 143 | std::runtime_error("Not enough memory: 
loadIndex failed to allocate data"); 144 | 145 | input.read(data_, maxelements_ * size_per_element_); 146 | 147 | input.close(); 148 | 149 | } 150 | 151 | }; 152 | } 153 | -------------------------------------------------------------------------------- /include/incremental_hnsw/hnswlib.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifndef HNSW_INCRE_ 4 | #define HNSW_INCRE_ 5 | #ifndef NO_MANUAL_VECTORIZATION 6 | #ifdef __SSE__ 7 | #define USE_SSE 8 | #ifdef __AVX__ 9 | #define USE_AVX 10 | #endif 11 | #endif 12 | #endif 13 | 14 | #if defined(USE_AVX) || defined(USE_SSE) 15 | #ifdef _MSC_VER 16 | #include 17 | #include 18 | #else 19 | #include 20 | #endif 21 | 22 | #if defined(__GNUC__) 23 | #define PORTABLE_ALIGN32 __attribute__((aligned(32))) 24 | #else 25 | #define PORTABLE_ALIGN32 __declspec(align(32)) 26 | #endif 27 | #endif 28 | 29 | #include 30 | #include 31 | #include 32 | #include 33 | 34 | using std::vector; 35 | 36 | namespace hnswlib_incre { 37 | typedef size_t labeltype; 38 | 39 | template class pairGreater { 40 | public: 41 | bool operator()(const T &p1, const T &p2) { return p1.first > p2.first; } 42 | }; 43 | 44 | template 45 | static void writeBinaryPOD(std::ostream &out, const T &podRef) { 46 | out.write((char *)&podRef, sizeof(T)); 47 | } 48 | 49 | template static void readBinaryPOD(std::istream &in, T &podRef) { 50 | in.read((char *)&podRef, sizeof(T)); 51 | } 52 | 53 | template 54 | using DISTFUNC = MTYPE (*)(const void *, const void *, const void *); 55 | 56 | template class SpaceInterface { 57 | public: 58 | // virtual void search(void *); 59 | virtual size_t get_data_size() = 0; 60 | 61 | virtual DISTFUNC get_dist_func() = 0; 62 | 63 | virtual void *get_dist_func_param() = 0; 64 | 65 | virtual ~SpaceInterface() {} 66 | }; 67 | 68 | template class AlgorithmInterface { 69 | public: 70 | void linkNeighbors(const void *data_point, labeltype label, 71 | vector neighbors) {} 72 | void 
addNeighborPoint(const void *data_point, labeltype label, int level) {} 73 | virtual void addPoint(const void *datapoint, labeltype label) = 0; 74 | // virtual std::priority_queue> 75 | // searchKnn(const void *, size_t,const int lbound,const int rbound, const int K_query) const; 76 | virtual std::priority_queue> 77 | searchKnnEF(const void *, size_t,const int lbound,const int rbound, const int K_query, const bool fixed_ef) const = 0; 78 | 79 | // Return k nearest neighbor in the order of closer fist 80 | virtual std::vector> 81 | searchKnnCloserFirst(const void *query_data, size_t k, const int lbound, 82 | const int rbound) const; 83 | 84 | 85 | virtual std::vector> 86 | searchKnnCloserFirst(const void *query_data, size_t k, const int lbound, 87 | const int rbound, const bool fixed_ef) const; 88 | 89 | virtual void saveIndex(const std::string &location) = 0; 90 | virtual ~AlgorithmInterface() {} 91 | }; 92 | 93 | template 94 | std::vector> 95 | AlgorithmInterface::searchKnnCloserFirst(const void *query_data, 96 | size_t k, const int lbound, 97 | const int rbound) const { 98 | std::vector> result; 99 | 100 | // here searchKnn returns the result in the order of further first 101 | auto ret = searchKnnEF(query_data, k, lbound, rbound, (int)k, false); 102 | { 103 | size_t sz = ret.size(); 104 | result.resize(sz); 105 | while (!ret.empty()) { 106 | result[--sz] = ret.top(); 107 | ret.pop(); 108 | } 109 | } 110 | 111 | return result; 112 | } 113 | 114 | template 115 | std::vector> 116 | AlgorithmInterface::searchKnnCloserFirst(const void *query_data, 117 | size_t k, const int lbound, 118 | const int rbound, const bool fixed_ef) const { 119 | std::vector> result; 120 | 121 | // here searchKnn returns the result in the order of further first 122 | auto ret = searchKnnEF(query_data, k, lbound, rbound, (int)k, fixed_ef); 123 | { 124 | size_t sz = ret.size(); 125 | result.resize(sz); 126 | while (!ret.empty()) { 127 | result[--sz] = ret.top(); 128 | ret.pop(); 129 | } 130 | 
} 131 | 132 | return result; 133 | } 134 | 135 | } // namespace hnswlib_incre 136 | 137 | #include "space_ip.h" 138 | #include "space_l2.h" 139 | // #include "bruteforce.h" 140 | #include "hnswalg.h" 141 | 142 | #endif -------------------------------------------------------------------------------- /include/incremental_hnsw/space_ip.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "hnswlib.h" 3 | 4 | namespace hnswlib_incre { 5 | 6 | static float 7 | InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { 8 | size_t qty = *((size_t *) qty_ptr); 9 | float res = 0; 10 | for (unsigned i = 0; i < qty; i++) { 11 | res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; 12 | } 13 | return (1.0f - res); 14 | 15 | } 16 | 17 | #if defined(USE_AVX) 18 | 19 | // Favor using AVX if available. 20 | static float 21 | InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 22 | float PORTABLE_ALIGN32 TmpRes[8]; 23 | float *pVect1 = (float *) pVect1v; 24 | float *pVect2 = (float *) pVect2v; 25 | size_t qty = *((size_t *) qty_ptr); 26 | 27 | size_t qty16 = qty / 16; 28 | size_t qty4 = qty / 4; 29 | 30 | const float *pEnd1 = pVect1 + 16 * qty16; 31 | const float *pEnd2 = pVect1 + 4 * qty4; 32 | 33 | __m256 sum256 = _mm256_set1_ps(0); 34 | 35 | while (pVect1 < pEnd1) { 36 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 37 | 38 | __m256 v1 = _mm256_loadu_ps(pVect1); 39 | pVect1 += 8; 40 | __m256 v2 = _mm256_loadu_ps(pVect2); 41 | pVect2 += 8; 42 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 43 | 44 | v1 = _mm256_loadu_ps(pVect1); 45 | pVect1 += 8; 46 | v2 = _mm256_loadu_ps(pVect2); 47 | pVect2 += 8; 48 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 49 | } 50 | 51 | __m128 v1, v2; 52 | __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); 53 | 54 | while (pVect1 < pEnd2) { 55 | v1 = _mm_loadu_ps(pVect1); 
56 | pVect1 += 4; 57 | v2 = _mm_loadu_ps(pVect2); 58 | pVect2 += 4; 59 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 60 | } 61 | 62 | _mm_store_ps(TmpRes, sum_prod); 63 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];; 64 | return 1.0f - sum; 65 | } 66 | 67 | #elif defined(USE_SSE) 68 | 69 | static float 70 | InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 71 | float PORTABLE_ALIGN32 TmpRes[8]; 72 | float *pVect1 = (float *) pVect1v; 73 | float *pVect2 = (float *) pVect2v; 74 | size_t qty = *((size_t *) qty_ptr); 75 | 76 | size_t qty16 = qty / 16; 77 | size_t qty4 = qty / 4; 78 | 79 | const float *pEnd1 = pVect1 + 16 * qty16; 80 | const float *pEnd2 = pVect1 + 4 * qty4; 81 | 82 | __m128 v1, v2; 83 | __m128 sum_prod = _mm_set1_ps(0); 84 | 85 | while (pVect1 < pEnd1) { 86 | v1 = _mm_loadu_ps(pVect1); 87 | pVect1 += 4; 88 | v2 = _mm_loadu_ps(pVect2); 89 | pVect2 += 4; 90 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 91 | 92 | v1 = _mm_loadu_ps(pVect1); 93 | pVect1 += 4; 94 | v2 = _mm_loadu_ps(pVect2); 95 | pVect2 += 4; 96 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 97 | 98 | v1 = _mm_loadu_ps(pVect1); 99 | pVect1 += 4; 100 | v2 = _mm_loadu_ps(pVect2); 101 | pVect2 += 4; 102 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 103 | 104 | v1 = _mm_loadu_ps(pVect1); 105 | pVect1 += 4; 106 | v2 = _mm_loadu_ps(pVect2); 107 | pVect2 += 4; 108 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 109 | } 110 | 111 | while (pVect1 < pEnd2) { 112 | v1 = _mm_loadu_ps(pVect1); 113 | pVect1 += 4; 114 | v2 = _mm_loadu_ps(pVect2); 115 | pVect2 += 4; 116 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 117 | } 118 | 119 | _mm_store_ps(TmpRes, sum_prod); 120 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 121 | 122 | return 1.0f - sum; 123 | } 124 | 125 | #endif 126 | 127 | #if defined(USE_AVX) 128 | 129 | static float 130 | InnerProductSIMD16Ext(const void *pVect1v, const void 
*pVect2v, const void *qty_ptr) { 131 | float PORTABLE_ALIGN32 TmpRes[8]; 132 | float *pVect1 = (float *) pVect1v; 133 | float *pVect2 = (float *) pVect2v; 134 | size_t qty = *((size_t *) qty_ptr); 135 | 136 | size_t qty16 = qty / 16; 137 | 138 | 139 | const float *pEnd1 = pVect1 + 16 * qty16; 140 | 141 | __m256 sum256 = _mm256_set1_ps(0); 142 | 143 | while (pVect1 < pEnd1) { 144 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 145 | 146 | __m256 v1 = _mm256_loadu_ps(pVect1); 147 | pVect1 += 8; 148 | __m256 v2 = _mm256_loadu_ps(pVect2); 149 | pVect2 += 8; 150 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 151 | 152 | v1 = _mm256_loadu_ps(pVect1); 153 | pVect1 += 8; 154 | v2 = _mm256_loadu_ps(pVect2); 155 | pVect2 += 8; 156 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 157 | } 158 | 159 | _mm256_store_ps(TmpRes, sum256); 160 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; 161 | 162 | return 1.0f - sum; 163 | } 164 | 165 | #elif defined(USE_SSE) 166 | 167 | static float 168 | InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 169 | float PORTABLE_ALIGN32 TmpRes[8]; 170 | float *pVect1 = (float *) pVect1v; 171 | float *pVect2 = (float *) pVect2v; 172 | size_t qty = *((size_t *) qty_ptr); 173 | 174 | size_t qty16 = qty / 16; 175 | 176 | const float *pEnd1 = pVect1 + 16 * qty16; 177 | 178 | __m128 v1, v2; 179 | __m128 sum_prod = _mm_set1_ps(0); 180 | 181 | while (pVect1 < pEnd1) { 182 | v1 = _mm_loadu_ps(pVect1); 183 | pVect1 += 4; 184 | v2 = _mm_loadu_ps(pVect2); 185 | pVect2 += 4; 186 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 187 | 188 | v1 = _mm_loadu_ps(pVect1); 189 | pVect1 += 4; 190 | v2 = _mm_loadu_ps(pVect2); 191 | pVect2 += 4; 192 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 193 | 194 | v1 = _mm_loadu_ps(pVect1); 195 | pVect1 += 4; 196 | v2 = _mm_loadu_ps(pVect2); 197 | pVect2 += 4; 198 | sum_prod = _mm_add_ps(sum_prod, 
_mm_mul_ps(v1, v2)); 199 | 200 | v1 = _mm_loadu_ps(pVect1); 201 | pVect1 += 4; 202 | v2 = _mm_loadu_ps(pVect2); 203 | pVect2 += 4; 204 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 205 | } 206 | _mm_store_ps(TmpRes, sum_prod); 207 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 208 | 209 | return 1.0f - sum; 210 | } 211 | 212 | #endif 213 | 214 | #if defined(USE_SSE) || defined(USE_AVX) 215 | static float 216 | InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 217 | size_t qty = *((size_t *) qty_ptr); 218 | size_t qty16 = qty >> 4 << 4; 219 | float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); 220 | float *pVect1 = (float *) pVect1v + qty16; 221 | float *pVect2 = (float *) pVect2v + qty16; 222 | 223 | size_t qty_left = qty - qty16; 224 | float res_tail = InnerProduct(pVect1, pVect2, &qty_left); 225 | return res + res_tail - 1.0f; 226 | } 227 | 228 | static float 229 | InnerProductSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 230 | size_t qty = *((size_t *) qty_ptr); 231 | size_t qty4 = qty >> 2 << 2; 232 | 233 | float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); 234 | size_t qty_left = qty - qty4; 235 | 236 | float *pVect1 = (float *) pVect1v + qty4; 237 | float *pVect2 = (float *) pVect2v + qty4; 238 | float res_tail = InnerProduct(pVect1, pVect2, &qty_left); 239 | 240 | return res + res_tail - 1.0f; 241 | } 242 | #endif 243 | 244 | class InnerProductSpace : public SpaceInterface { 245 | 246 | DISTFUNC fstdistfunc_; 247 | size_t data_size_; 248 | size_t dim_; 249 | public: 250 | InnerProductSpace(size_t dim) { 251 | fstdistfunc_ = InnerProduct; 252 | #if defined(USE_AVX) || defined(USE_SSE) 253 | if (dim % 16 == 0) 254 | fstdistfunc_ = InnerProductSIMD16Ext; 255 | else if (dim % 4 == 0) 256 | fstdistfunc_ = InnerProductSIMD4Ext; 257 | else if (dim > 16) 258 | fstdistfunc_ = InnerProductSIMD16ExtResiduals; 259 | else if (dim > 4) 260 | 
fstdistfunc_ = InnerProductSIMD4ExtResiduals; 261 | #endif 262 | dim_ = dim; 263 | data_size_ = dim * sizeof(float); 264 | } 265 | 266 | size_t get_data_size() { 267 | return data_size_; 268 | } 269 | 270 | DISTFUNC get_dist_func() { 271 | return fstdistfunc_; 272 | } 273 | 274 | void *get_dist_func_param() { 275 | return &dim_; 276 | } 277 | 278 | ~InnerProductSpace() {} 279 | }; 280 | 281 | 282 | } 283 | -------------------------------------------------------------------------------- /include/incremental_hnsw/space_l2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "hnswlib.h" 3 | 4 | namespace hnswlib_incre { 5 | 6 | static float 7 | L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 8 | // std::cout<<"isme"<> 4; 33 | 34 | const float *pEnd1 = pVect1 + (qty16 << 4); 35 | 36 | __m256 diff, v1, v2; 37 | __m256 sum = _mm256_set1_ps(0); 38 | 39 | while (pVect1 < pEnd1) { 40 | v1 = _mm256_loadu_ps(pVect1); 41 | pVect1 += 8; 42 | v2 = _mm256_loadu_ps(pVect2); 43 | pVect2 += 8; 44 | diff = _mm256_sub_ps(v1, v2); 45 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); 46 | 47 | v1 = _mm256_loadu_ps(pVect1); 48 | pVect1 += 8; 49 | v2 = _mm256_loadu_ps(pVect2); 50 | pVect2 += 8; 51 | diff = _mm256_sub_ps(v1, v2); 52 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); 53 | } 54 | 55 | _mm256_store_ps(TmpRes, sum); 56 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; 57 | } 58 | 59 | #elif defined(USE_SSE) 60 | 61 | static float 62 | L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 63 | float *pVect1 = (float *) pVect1v; 64 | float *pVect2 = (float *) pVect2v; 65 | size_t qty = *((size_t *) qty_ptr); 66 | float PORTABLE_ALIGN32 TmpRes[8]; 67 | size_t qty16 = qty >> 4; 68 | 69 | const float *pEnd1 = pVect1 + (qty16 << 4); 70 | 71 | __m128 diff, v1, v2; 72 | __m128 sum = _mm_set1_ps(0); 73 | 74 | while 
(pVect1 < pEnd1) { 75 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 76 | v1 = _mm_loadu_ps(pVect1); 77 | pVect1 += 4; 78 | v2 = _mm_loadu_ps(pVect2); 79 | pVect2 += 4; 80 | diff = _mm_sub_ps(v1, v2); 81 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 82 | 83 | v1 = _mm_loadu_ps(pVect1); 84 | pVect1 += 4; 85 | v2 = _mm_loadu_ps(pVect2); 86 | pVect2 += 4; 87 | diff = _mm_sub_ps(v1, v2); 88 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 89 | 90 | v1 = _mm_loadu_ps(pVect1); 91 | pVect1 += 4; 92 | v2 = _mm_loadu_ps(pVect2); 93 | pVect2 += 4; 94 | diff = _mm_sub_ps(v1, v2); 95 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 96 | 97 | v1 = _mm_loadu_ps(pVect1); 98 | pVect1 += 4; 99 | v2 = _mm_loadu_ps(pVect2); 100 | pVect2 += 4; 101 | diff = _mm_sub_ps(v1, v2); 102 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 103 | } 104 | 105 | _mm_store_ps(TmpRes, sum); 106 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 107 | } 108 | #endif 109 | 110 | #if defined(USE_SSE) || defined(USE_AVX) 111 | static float 112 | L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 113 | size_t qty = *((size_t *) qty_ptr); 114 | size_t qty16 = qty >> 4 << 4; 115 | float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); 116 | float *pVect1 = (float *) pVect1v + qty16; 117 | float *pVect2 = (float *) pVect2v + qty16; 118 | 119 | size_t qty_left = qty - qty16; 120 | float res_tail = L2Sqr(pVect1, pVect2, &qty_left); 121 | return (res + res_tail); 122 | } 123 | #endif 124 | 125 | 126 | #ifdef USE_SSE 127 | static float 128 | L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 129 | float PORTABLE_ALIGN32 TmpRes[8]; 130 | float *pVect1 = (float *) pVect1v; 131 | float *pVect2 = (float *) pVect2v; 132 | size_t qty = *((size_t *) qty_ptr); 133 | 134 | 135 | size_t qty4 = qty >> 2; 136 | 137 | const float *pEnd1 = pVect1 + (qty4 << 2); 138 | 139 | __m128 diff, v1, v2; 140 | __m128 sum = _mm_set1_ps(0); 141 | 142 | while 
(pVect1 < pEnd1) { 143 | v1 = _mm_loadu_ps(pVect1); 144 | pVect1 += 4; 145 | v2 = _mm_loadu_ps(pVect2); 146 | pVect2 += 4; 147 | diff = _mm_sub_ps(v1, v2); 148 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 149 | } 150 | _mm_store_ps(TmpRes, sum); 151 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 152 | } 153 | 154 | static float 155 | L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 156 | size_t qty = *((size_t *) qty_ptr); 157 | size_t qty4 = qty >> 2 << 2; 158 | 159 | float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); 160 | size_t qty_left = qty - qty4; 161 | 162 | float *pVect1 = (float *) pVect1v + qty4; 163 | float *pVect2 = (float *) pVect2v + qty4; 164 | float res_tail = L2Sqr(pVect1, pVect2, &qty_left); 165 | 166 | return (res + res_tail); 167 | } 168 | #endif 169 | 170 | class L2Space : public SpaceInterface { 171 | 172 | DISTFUNC fstdistfunc_; 173 | size_t data_size_; 174 | size_t dim_; 175 | public: 176 | L2Space(size_t dim) { 177 | fstdistfunc_ = L2Sqr; 178 | #if defined(USE_SSE) || defined(USE_AVX) 179 | if (dim % 16 == 0) 180 | fstdistfunc_ = L2SqrSIMD16Ext; 181 | else if (dim % 4 == 0) 182 | fstdistfunc_ = L2SqrSIMD4Ext; 183 | else if (dim > 16) 184 | fstdistfunc_ = L2SqrSIMD16ExtResiduals; 185 | else if (dim > 4) 186 | fstdistfunc_ = L2SqrSIMD4ExtResiduals; 187 | #endif 188 | dim_ = dim; 189 | data_size_ = dim * sizeof(float); 190 | } 191 | 192 | size_t get_data_size() { 193 | return data_size_; 194 | } 195 | 196 | DISTFUNC get_dist_func() { 197 | return fstdistfunc_; 198 | } 199 | 200 | void *get_dist_func_param() { 201 | return &dim_; 202 | } 203 | 204 | ~L2Space() {} 205 | }; 206 | 207 | static int 208 | L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { 209 | 210 | size_t qty = *((size_t *) qty_ptr); 211 | int res = 0; 212 | unsigned char *a = (unsigned char *) pVect1; 213 | unsigned char *b = (unsigned char *) pVect2; 214 | 215 | qty = qty 
>> 2; 216 | for (size_t i = 0; i < qty; i++) { 217 | 218 | res += ((*a) - (*b)) * ((*a) - (*b)); 219 | a++; 220 | b++; 221 | res += ((*a) - (*b)) * ((*a) - (*b)); 222 | a++; 223 | b++; 224 | res += ((*a) - (*b)) * ((*a) - (*b)); 225 | a++; 226 | b++; 227 | res += ((*a) - (*b)) * ((*a) - (*b)); 228 | a++; 229 | b++; 230 | } 231 | return (res); 232 | } 233 | 234 | static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) { 235 | size_t qty = *((size_t*)qty_ptr); 236 | int res = 0; 237 | unsigned char* a = (unsigned char*)pVect1; 238 | unsigned char* b = (unsigned char*)pVect2; 239 | 240 | for(size_t i = 0; i < qty; i++) 241 | { 242 | res += ((*a) - (*b)) * ((*a) - (*b)); 243 | a++; 244 | b++; 245 | } 246 | return (res); 247 | } 248 | 249 | class L2SpaceI : public SpaceInterface { 250 | 251 | DISTFUNC fstdistfunc_; 252 | size_t data_size_; 253 | size_t dim_; 254 | public: 255 | L2SpaceI(size_t dim) { 256 | if(dim % 4 == 0) { 257 | fstdistfunc_ = L2SqrI4x; 258 | } 259 | else { 260 | fstdistfunc_ = L2SqrI; 261 | } 262 | dim_ = dim; 263 | data_size_ = dim * sizeof(unsigned char); 264 | } 265 | 266 | size_t get_data_size() { 267 | return data_size_; 268 | } 269 | 270 | DISTFUNC get_dist_func() { 271 | return fstdistfunc_; 272 | } 273 | 274 | void *get_dist_func_param() { 275 | return &dim_; 276 | } 277 | 278 | ~L2SpaceI() {} 279 | }; 280 | 281 | 282 | } -------------------------------------------------------------------------------- /include/incremental_hnsw/visited_list_pool.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace hnswlib_incre { 8 | typedef unsigned short int vl_type; 9 | 10 | class VisitedList { 11 | public: 12 | vl_type curV; 13 | vl_type *mass; 14 | unsigned int numelements; 15 | 16 | VisitedList(int numelements1) { 17 | curV = -1; 18 | numelements = numelements1; 19 | mass = new 
vl_type[numelements]; 20 | } 21 | 22 | void reset() { 23 | curV++; 24 | if (curV == 0) { 25 | memset(mass, 0, sizeof(vl_type) * numelements); 26 | curV++; 27 | } 28 | }; 29 | 30 | ~VisitedList() { delete[] mass; } 31 | }; 32 | /////////////////////////////////////////////////////////// 33 | // 34 | // Class for multi-threaded pool-management of VisitedLists 35 | // 36 | ///////////////////////////////////////////////////////// 37 | 38 | class VisitedListPool { 39 | std::deque pool; 40 | std::mutex poolguard; 41 | int numelements; 42 | 43 | public: 44 | VisitedListPool(int initmaxpools, int numelements1) { 45 | numelements = numelements1; 46 | for (int i = 0; i < initmaxpools; i++) 47 | pool.push_front(new VisitedList(numelements)); 48 | } 49 | 50 | VisitedList *getFreeVisitedList() { 51 | VisitedList *rez; 52 | { 53 | std::unique_lock lock(poolguard); 54 | if (pool.size() > 0) { 55 | rez = pool.front(); 56 | pool.pop_front(); 57 | } else { 58 | rez = new VisitedList(numelements); 59 | } 60 | } 61 | rez->reset(); 62 | return rez; 63 | }; 64 | 65 | void releaseVisitedList(VisitedList *vl) { 66 | std::unique_lock lock(poolguard); 67 | pool.push_front(vl); 68 | }; 69 | 70 | ~VisitedListPool() { 71 | while (pool.size()) { 72 | VisitedList *rez = pool.front(); 73 | pool.pop_front(); 74 | delete rez; 75 | } 76 | }; 77 | }; 78 | } 79 | 80 | -------------------------------------------------------------------------------- /sample_data/deep_10k.fvecs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rutgers-db/SeRF/5666c2a461ed75b10b8b5aafb43b6eb884538ddf/sample_data/deep_10k.fvecs -------------------------------------------------------------------------------- /sample_data/deep_query.fvecs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rutgers-db/SeRF/5666c2a461ed75b10b8b5aafb43b6eb884538ddf/sample_data/deep_query.fvecs 
-------------------------------------------------------------------------------- /src/base_hnsw/bruteforce.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | namespace base_hnsw { 8 | template 9 | class BruteforceSearch : public AlgorithmInterface { 10 | public: 11 | BruteforceSearch(SpaceInterface *s) {} 12 | BruteforceSearch(SpaceInterface *s, const std::string &location) { 13 | loadIndex(location, s); 14 | } 15 | 16 | BruteforceSearch(SpaceInterface *s, size_t maxElements) { 17 | maxelements_ = maxElements; 18 | data_size_ = s->get_data_size(); 19 | fstdistfunc_ = s->get_dist_func(); 20 | dist_func_param_ = s->get_dist_func_param(); 21 | size_per_element_ = data_size_ + sizeof(labeltype); 22 | data_ = (char *)malloc(maxElements * size_per_element_); 23 | if (data_ == nullptr) 24 | std::runtime_error( 25 | "Not enough memory: BruteforceSearch failed to allocate data"); 26 | cur_element_count = 0; 27 | } 28 | 29 | ~BruteforceSearch() { free(data_); } 30 | 31 | char *data_; 32 | size_t maxelements_; 33 | size_t cur_element_count; 34 | size_t size_per_element_; 35 | 36 | size_t data_size_; 37 | DISTFUNC fstdistfunc_; 38 | void *dist_func_param_; 39 | std::mutex index_lock; 40 | 41 | std::unordered_map dict_external_to_internal; 42 | 43 | void addPoint(const void *datapoint, labeltype label) { 44 | int idx; 45 | { 46 | std::unique_lock lock(index_lock); 47 | 48 | auto search = dict_external_to_internal.find(label); 49 | if (search != dict_external_to_internal.end()) { 50 | idx = search->second; 51 | } else { 52 | if (cur_element_count >= maxelements_) { 53 | throw std::runtime_error( 54 | "The number of elements exceeds the specified limit\n"); 55 | } 56 | idx = cur_element_count; 57 | dict_external_to_internal[label] = idx; 58 | cur_element_count++; 59 | } 60 | } 61 | memcpy(data_ + size_per_element_ * idx + data_size_, &label, 62 | sizeof(labeltype)); 63 | 
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); 64 | }; 65 | 66 | void removePoint(labeltype cur_external) { 67 | size_t cur_c = dict_external_to_internal[cur_external]; 68 | 69 | dict_external_to_internal.erase(cur_external); 70 | 71 | labeltype label = 72 | *((labeltype *)(data_ + size_per_element_ * (cur_element_count - 1) + 73 | data_size_)); 74 | dict_external_to_internal[label] = cur_c; 75 | memcpy(data_ + size_per_element_ * cur_c, 76 | data_ + size_per_element_ * (cur_element_count - 1), 77 | data_size_ + sizeof(labeltype)); 78 | cur_element_count--; 79 | } 80 | 81 | std::priority_queue> searchKnn( 82 | const void *query_data, size_t k) const { 83 | std::priority_queue> topResults; 84 | if (cur_element_count == 0) return topResults; 85 | for (int i = 0; i < k; i++) { 86 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, 87 | dist_func_param_); 88 | topResults.push(std::pair( 89 | dist, *((labeltype *)(data_ + size_per_element_ * i + data_size_)))); 90 | } 91 | dist_t lastdist = topResults.top().first; 92 | for (int i = k; i < cur_element_count; i++) { 93 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, 94 | dist_func_param_); 95 | if (dist <= lastdist) { 96 | topResults.push(std::pair( 97 | dist, 98 | *((labeltype *)(data_ + size_per_element_ * i + data_size_)))); 99 | if (topResults.size() > k) topResults.pop(); 100 | lastdist = topResults.top().first; 101 | } 102 | } 103 | return topResults; 104 | }; 105 | 106 | void saveIndex(const std::string &location) { 107 | std::ofstream output(location, std::ios::binary); 108 | std::streampos position; 109 | 110 | writeBinaryPOD(output, maxelements_); 111 | writeBinaryPOD(output, size_per_element_); 112 | writeBinaryPOD(output, cur_element_count); 113 | 114 | output.write(data_, maxelements_ * size_per_element_); 115 | 116 | output.close(); 117 | } 118 | 119 | void loadIndex(const std::string &location, SpaceInterface *s) { 120 | std::ifstream 
input(location, std::ios::binary); 121 | std::streampos position; 122 | 123 | readBinaryPOD(input, maxelements_); 124 | readBinaryPOD(input, size_per_element_); 125 | readBinaryPOD(input, cur_element_count); 126 | 127 | data_size_ = s->get_data_size(); 128 | fstdistfunc_ = s->get_dist_func(); 129 | dist_func_param_ = s->get_dist_func_param(); 130 | size_per_element_ = data_size_ + sizeof(labeltype); 131 | data_ = (char *)malloc(maxelements_ * size_per_element_); 132 | if (data_ == nullptr) 133 | std::runtime_error( 134 | "Not enough memory: loadIndex failed to allocate data"); 135 | 136 | input.read(data_, maxelements_ * size_per_element_); 137 | 138 | input.close(); 139 | } 140 | }; 141 | } // namespace hnswlib_compose 142 | -------------------------------------------------------------------------------- /src/base_hnsw/hnswlib.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | acknowledge: 5 | update on 09/09/2024, based on the latest hnswlib code from: https://github.com/nmslib/hnswlib/blob/c1b9b79af3d10c6ee7b5d0afa1ce851ae975254c/hnswlib/hnswlib.h 6 | */ 7 | 8 | #ifndef NO_MANUAL_VECTORIZATION 9 | #if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64)) 10 | #define USE_SSE 11 | #ifdef __AVX__ 12 | #define USE_AVX 13 | #ifdef __AVX512F__ 14 | #define USE_AVX512 15 | #endif 16 | #endif 17 | #endif 18 | #endif 19 | 20 | #if defined(USE_AVX) || defined(USE_SSE) 21 | #ifdef _MSC_VER 22 | #include 23 | #include 24 | static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) { 25 | __cpuidex(out, eax, ecx); 26 | } 27 | static __int64 xgetbv(unsigned int x) { 28 | return _xgetbv(x); 29 | } 30 | #else 31 | #include 32 | #include 33 | #include 34 | static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) { 35 | __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]); 36 | } 37 | static uint64_t xgetbv(unsigned int index) { 38 | uint32_t eax, edx; 39 | __asm__ 
__volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index)); 40 | return ((uint64_t)edx << 32) | eax; 41 | } 42 | #endif 43 | 44 | #if defined(USE_AVX512) 45 | #include 46 | #endif 47 | 48 | #if defined(__GNUC__) 49 | #define PORTABLE_ALIGN32 __attribute__((aligned(32))) 50 | #define PORTABLE_ALIGN64 __attribute__((aligned(64))) 51 | #else 52 | #define PORTABLE_ALIGN32 __declspec(align(32)) 53 | #define PORTABLE_ALIGN64 __declspec(align(64)) 54 | #endif 55 | 56 | // Adapted from https://github.com/Mysticial/FeatureDetector 57 | #define _XCR_XFEATURE_ENABLED_MASK 0 58 | 59 | static bool AVXCapable() { 60 | int cpuInfo[4]; 61 | 62 | // CPU support 63 | cpuid(cpuInfo, 0, 0); 64 | int nIds = cpuInfo[0]; 65 | 66 | bool HW_AVX = false; 67 | if (nIds >= 0x00000001) { 68 | cpuid(cpuInfo, 0x00000001, 0); 69 | HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0; 70 | } 71 | 72 | // OS support 73 | cpuid(cpuInfo, 1, 0); 74 | 75 | bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; 76 | bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0; 77 | 78 | bool avxSupported = false; 79 | if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { 80 | uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK); 81 | avxSupported = (xcrFeatureMask & 0x6) == 0x6; 82 | } 83 | return HW_AVX && avxSupported; 84 | } 85 | 86 | static bool AVX512Capable() { 87 | if (!AVXCapable()) return false; 88 | 89 | int cpuInfo[4]; 90 | 91 | // CPU support 92 | cpuid(cpuInfo, 0, 0); 93 | int nIds = cpuInfo[0]; 94 | 95 | bool HW_AVX512F = false; 96 | if (nIds >= 0x00000007) { // AVX512 Foundation 97 | cpuid(cpuInfo, 0x00000007, 0); 98 | HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0; 99 | } 100 | 101 | // OS support 102 | cpuid(cpuInfo, 1, 0); 103 | 104 | bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; 105 | bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0; 106 | 107 | bool avx512Supported = false; 108 | if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { 109 | uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK); 
110 | avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6; 111 | } 112 | return HW_AVX512F && avx512Supported; 113 | } 114 | #endif 115 | 116 | #include 117 | 118 | #include 119 | #include 120 | #include 121 | 122 | namespace base_hnsw { 123 | typedef size_t labeltype; 124 | 125 | template 126 | class pairGreater { 127 | public: 128 | bool operator()(const T &p1, const T &p2) { return p1.first > p2.first; } 129 | }; 130 | 131 | template 132 | static void writeBinaryPOD(std::ostream &out, const T &podRef) { 133 | out.write((char *)&podRef, sizeof(T)); 134 | } 135 | 136 | template 137 | static void readBinaryPOD(std::istream &in, T &podRef) { 138 | in.read((char *)&podRef, sizeof(T)); 139 | } 140 | 141 | template 142 | using DISTFUNC = MTYPE (*)(const void *, const void *, const void *); 143 | 144 | template 145 | class SpaceInterface { 146 | public: 147 | // virtual void search(void *); 148 | virtual size_t get_data_size() = 0; 149 | 150 | virtual DISTFUNC get_dist_func() = 0; 151 | 152 | virtual void *get_dist_func_param() = 0; 153 | 154 | virtual ~SpaceInterface() {} 155 | }; 156 | 157 | template 158 | class AlgorithmInterface { 159 | public: 160 | void linkNeighbors(const void *data_point, labeltype label, 161 | vector neighbors) {} 162 | void addNeighborPoint(const void *data_point, labeltype label, int level) {} 163 | virtual void addPoint(const void *datapoint, labeltype label) = 0; 164 | virtual std::priority_queue> searchKnn( 165 | const void *, size_t) const = 0; 166 | 167 | // Return k nearest neighbor in the order of closer fist 168 | virtual std::vector> searchKnnCloserFirst( 169 | const void *query_data, size_t k) const; 170 | 171 | virtual void saveIndex(const std::string &location) = 0; 172 | virtual ~AlgorithmInterface() {} 173 | }; 174 | 175 | template 176 | std::vector> 177 | AlgorithmInterface::searchKnnCloserFirst(const void *query_data, 178 | size_t k) const { 179 | std::vector> result; 180 | 181 | // here searchKnn returns the result in the order 
of further first 182 | auto ret = searchKnn(query_data, k); 183 | { 184 | size_t sz = ret.size(); 185 | result.resize(sz); 186 | while (!ret.empty()) { 187 | result[--sz] = ret.top(); 188 | ret.pop(); 189 | } 190 | } 191 | 192 | return result; 193 | } 194 | 195 | } // namespace base_hnsw 196 | 197 | #include "bruteforce.h" 198 | #include "hnswalg.h" 199 | #include "space_ip.h" 200 | #include "space_l2.h" 201 | -------------------------------------------------------------------------------- /src/base_hnsw/space_ip.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "hnswlib.h" 3 | 4 | namespace base_hnsw { 5 | 6 | static float InnerProduct(const void *pVect1, const void *pVect2, 7 | const void *qty_ptr) { 8 | size_t qty = *((size_t *)qty_ptr); 9 | float res = 0; 10 | for (unsigned i = 0; i < qty; i++) { 11 | res += ((float *)pVect1)[i] * ((float *)pVect2)[i]; 12 | } 13 | return (1.0f - res); 14 | } 15 | 16 | #if defined(USE_AVX) 17 | 18 | // Favor using AVX if available. 
19 | static float InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, 20 | const void *qty_ptr) { 21 | float PORTABLE_ALIGN32 TmpRes[8]; 22 | float *pVect1 = (float *)pVect1v; 23 | float *pVect2 = (float *)pVect2v; 24 | size_t qty = *((size_t *)qty_ptr); 25 | 26 | size_t qty16 = qty / 16; 27 | size_t qty4 = qty / 4; 28 | 29 | const float *pEnd1 = pVect1 + 16 * qty16; 30 | const float *pEnd2 = pVect1 + 4 * qty4; 31 | 32 | __m256 sum256 = _mm256_set1_ps(0); 33 | 34 | while (pVect1 < pEnd1) { 35 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 36 | 37 | __m256 v1 = _mm256_loadu_ps(pVect1); 38 | pVect1 += 8; 39 | __m256 v2 = _mm256_loadu_ps(pVect2); 40 | pVect2 += 8; 41 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 42 | 43 | v1 = _mm256_loadu_ps(pVect1); 44 | pVect1 += 8; 45 | v2 = _mm256_loadu_ps(pVect2); 46 | pVect2 += 8; 47 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 48 | } 49 | 50 | __m128 v1, v2; 51 | __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), 52 | _mm256_extractf128_ps(sum256, 1)); 53 | 54 | while (pVect1 < pEnd2) { 55 | v1 = _mm_loadu_ps(pVect1); 56 | pVect1 += 4; 57 | v2 = _mm_loadu_ps(pVect2); 58 | pVect2 += 4; 59 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 60 | } 61 | 62 | _mm_store_ps(TmpRes, sum_prod); 63 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 64 | ; 65 | return 1.0f - sum; 66 | } 67 | 68 | #elif defined(USE_SSE) 69 | 70 | static float InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, 71 | const void *qty_ptr) { 72 | float PORTABLE_ALIGN32 TmpRes[8]; 73 | float *pVect1 = (float *)pVect1v; 74 | float *pVect2 = (float *)pVect2v; 75 | size_t qty = *((size_t *)qty_ptr); 76 | 77 | size_t qty16 = qty / 16; 78 | size_t qty4 = qty / 4; 79 | 80 | const float *pEnd1 = pVect1 + 16 * qty16; 81 | const float *pEnd2 = pVect1 + 4 * qty4; 82 | 83 | __m128 v1, v2; 84 | __m128 sum_prod = _mm_set1_ps(0); 85 | 86 | while (pVect1 < pEnd1) { 87 | v1 = _mm_loadu_ps(pVect1); 88 
| pVect1 += 4; 89 | v2 = _mm_loadu_ps(pVect2); 90 | pVect2 += 4; 91 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 92 | 93 | v1 = _mm_loadu_ps(pVect1); 94 | pVect1 += 4; 95 | v2 = _mm_loadu_ps(pVect2); 96 | pVect2 += 4; 97 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 98 | 99 | v1 = _mm_loadu_ps(pVect1); 100 | pVect1 += 4; 101 | v2 = _mm_loadu_ps(pVect2); 102 | pVect2 += 4; 103 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 104 | 105 | v1 = _mm_loadu_ps(pVect1); 106 | pVect1 += 4; 107 | v2 = _mm_loadu_ps(pVect2); 108 | pVect2 += 4; 109 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 110 | } 111 | 112 | while (pVect1 < pEnd2) { 113 | v1 = _mm_loadu_ps(pVect1); 114 | pVect1 += 4; 115 | v2 = _mm_loadu_ps(pVect2); 116 | pVect2 += 4; 117 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 118 | } 119 | 120 | _mm_store_ps(TmpRes, sum_prod); 121 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 122 | 123 | return 1.0f - sum; 124 | } 125 | 126 | #endif 127 | 128 | #if defined(USE_AVX) 129 | 130 | static float InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, 131 | const void *qty_ptr) { 132 | float PORTABLE_ALIGN32 TmpRes[8]; 133 | float *pVect1 = (float *)pVect1v; 134 | float *pVect2 = (float *)pVect2v; 135 | size_t qty = *((size_t *)qty_ptr); 136 | 137 | size_t qty16 = qty / 16; 138 | 139 | const float *pEnd1 = pVect1 + 16 * qty16; 140 | 141 | __m256 sum256 = _mm256_set1_ps(0); 142 | 143 | while (pVect1 < pEnd1) { 144 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 145 | 146 | __m256 v1 = _mm256_loadu_ps(pVect1); 147 | pVect1 += 8; 148 | __m256 v2 = _mm256_loadu_ps(pVect2); 149 | pVect2 += 8; 150 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 151 | 152 | v1 = _mm256_loadu_ps(pVect1); 153 | pVect1 += 8; 154 | v2 = _mm256_loadu_ps(pVect2); 155 | pVect2 += 8; 156 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 157 | } 158 | 159 | _mm256_store_ps(TmpRes, sum256); 160 | float sum = TmpRes[0] + 
TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + 161 | TmpRes[5] + TmpRes[6] + TmpRes[7]; 162 | 163 | return 1.0f - sum; 164 | } 165 | 166 | #elif defined(USE_SSE) 167 | 168 | /* SSE kernel for the inner-product space: accumulates the dot product of the first 16 * (qty / 16) floats with four unaligned 4-wide loads per vector per iteration and returns 1.0f - dot. Elements past the last multiple of 16 are not read; InnerProductSpace installs this kernel directly only when dim % 16 == 0. */ static float InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, 169 | const void *qty_ptr) { 170 | float PORTABLE_ALIGN32 TmpRes[8]; 171 | float *pVect1 = (float *)pVect1v; 172 | float *pVect2 = (float *)pVect2v; 173 | size_t qty = *((size_t *)qty_ptr); 174 | 175 | size_t qty16 = qty / 16; 176 | 177 | const float *pEnd1 = pVect1 + 16 * qty16; 178 | 179 | __m128 v1, v2; 180 | __m128 sum_prod = _mm_set1_ps(0); 181 | 182 | while (pVect1 < pEnd1) { 183 | v1 = _mm_loadu_ps(pVect1); 184 | pVect1 += 4; 185 | v2 = _mm_loadu_ps(pVect2); 186 | pVect2 += 4; 187 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 188 | 189 | v1 = _mm_loadu_ps(pVect1); 190 | pVect1 += 4; 191 | v2 = _mm_loadu_ps(pVect2); 192 | pVect2 += 4; 193 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 194 | 195 | v1 = _mm_loadu_ps(pVect1); 196 | pVect1 += 4; 197 | v2 = _mm_loadu_ps(pVect2); 198 | pVect2 += 4; 199 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 200 | 201 | v1 = _mm_loadu_ps(pVect1); 202 | pVect1 += 4; 203 | v2 = _mm_loadu_ps(pVect2); 204 | pVect2 += 4; 205 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 206 | } 207 | _mm_store_ps(TmpRes, sum_prod); 208 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 209 | 210 | return 1.0f - sum; 211 | } 212 | 213 | #endif 214 | 215 | #if defined(USE_SSE) || defined(USE_AVX) 216 | /* Inner-product kernel for qty not divisible by 16: runs InnerProductSIMD16Ext on the leading multiple of 16 and the scalar InnerProduct on the remaining elements. Both partial results have the form 1 - partial_dot (assuming InnerProduct, defined above this chunk, follows that convention), so one 1.0f is subtracted to yield 1 - full_dot. */ static float InnerProductSIMD16ExtResiduals(const void *pVect1v, 217 | const void *pVect2v, 218 | const void *qty_ptr) { 219 | size_t qty = *((size_t *)qty_ptr); 220 | size_t qty16 = qty >> 4 << 4; 221 | float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); 222 | float *pVect1 = (float *)pVect1v + qty16; 223 | float *pVect2 = (float *)pVect2v + qty16; 224 | 225 | size_t qty_left = qty - qty16; 226 | float res_tail = InnerProduct(pVect1, pVect2,
&qty_left); 227 | return res + res_tail - 1.0f; 228 | } 229 | 230 | /* Same split as InnerProductSIMD16ExtResiduals but at a multiple of 4: InnerProductSIMD4Ext (defined above this chunk) handles the leading 4 * (qty / 4) elements and the scalar InnerProduct handles the remaining (at most 3) elements; one 1.0f is subtracted so the result is 1 - full_dot. */ static float InnerProductSIMD4ExtResiduals(const void *pVect1v, 231 | const void *pVect2v, 232 | const void *qty_ptr) { 233 | size_t qty = *((size_t *)qty_ptr); 234 | size_t qty4 = qty >> 2 << 2; 235 | 236 | float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); 237 | size_t qty_left = qty - qty4; 238 | 239 | float *pVect1 = (float *)pVect1v + qty4; 240 | float *pVect2 = (float *)pVect2v + qty4; 241 | float res_tail = InnerProduct(pVect1, pVect2, &qty_left); 242 | 243 | return res + res_tail - 1.0f; 244 | } 245 | #endif 246 | 247 | /* Space for float vectors of length dim under an inner-product distance (the kernels return 1 - dot). The constructor defaults to the scalar InnerProduct and, when USE_SSE or USE_AVX is defined, switches to the 16-wide kernel for dim % 16 == 0, the 4-wide kernel for dim % 4 == 0, and otherwise the residual variants for dim > 16 or dim > 4; dims 1-3 stay scalar. get_dist_func_param() exposes &dim_, which the kernels read as their element count. */ class InnerProductSpace : public SpaceInterface { 248 | DISTFUNC fstdistfunc_; 249 | size_t data_size_; 250 | size_t dim_; 251 | 252 | public: 253 | InnerProductSpace(size_t dim) { 254 | fstdistfunc_ = InnerProduct; 255 | #if defined(USE_AVX) || defined(USE_SSE) 256 | if (dim % 16 == 0) 257 | fstdistfunc_ = InnerProductSIMD16Ext; 258 | else if (dim % 4 == 0) 259 | fstdistfunc_ = InnerProductSIMD4Ext; 260 | else if (dim > 16) 261 | fstdistfunc_ = InnerProductSIMD16ExtResiduals; 262 | else if (dim > 4) 263 | fstdistfunc_ = InnerProductSIMD4ExtResiduals; 264 | #endif 265 | dim_ = dim; 266 | data_size_ = dim * sizeof(float); 267 | } 268 | 269 | size_t get_data_size() { return data_size_; } 270 | 271 | DISTFUNC get_dist_func() { return fstdistfunc_; } 272 | 273 | void *get_dist_func_param() { return &dim_; } 274 | 275 | ~InnerProductSpace() {} 276 | }; 277 | 278 | } // namespace hnswlib_compose 279 | -------------------------------------------------------------------------------- /src/base_hnsw/space_l2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "hnswlib.h" 3 | 4 | /* 5 | acknowledge: 6 | update on 09/09/2024, based on the latest hnswlib code from: https://github.com/nmslib/hnswlib/blob/c1b9b79af3d10c6ee7b5d0afa1ce851ae975254c/hnswlib/space_l2.h 7 | */ 8 | 9 | namespace base_hnsw { 10 | 11 | /* Scalar squared Euclidean distance over qty floats; no square root is taken. */ static
float 12 | L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 13 | float *pVect1 = (float *) pVect1v; 14 | float *pVect2 = (float *) pVect2v; 15 | size_t qty = *((size_t *) qty_ptr); 16 | 17 | float res = 0; 18 | for (size_t i = 0; i < qty; i++) { 19 | float t = *pVect1 - *pVect2; 20 | pVect1++; 21 | pVect2++; 22 | res += t * t; 23 | } 24 | return (res); 25 | } 26 | 27 | #if defined(USE_AVX512) 28 | 29 | // Favor using AVX512 if available. 30 | /* AVX-512 kernel: squared L2 distance over the first 16 * (qty / 16) floats using one unaligned 16-wide load per vector per iteration; trailing elements are not read. */ static float 31 | L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 32 | float *pVect1 = (float *) pVect1v; 33 | float *pVect2 = (float *) pVect2v; 34 | size_t qty = *((size_t *) qty_ptr); 35 | float PORTABLE_ALIGN64 TmpRes[16]; 36 | size_t qty16 = qty >> 4; 37 | 38 | const float *pEnd1 = pVect1 + (qty16 << 4); 39 | 40 | __m512 diff, v1, v2; 41 | __m512 sum = _mm512_set1_ps(0); 42 | 43 | while (pVect1 < pEnd1) { 44 | v1 = _mm512_loadu_ps(pVect1); 45 | pVect1 += 16; 46 | v2 = _mm512_loadu_ps(pVect2); 47 | pVect2 += 16; 48 | diff = _mm512_sub_ps(v1, v2); 49 | // sum = _mm512_fmadd_ps(diff, diff, sum); 50 | sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff)); 51 | } 52 | 53 | _mm512_store_ps(TmpRes, sum); 54 | float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + 55 | TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + 56 | TmpRes[13] + TmpRes[14] + TmpRes[15]; 57 | 58 | return (res); 59 | } 60 | #endif 61 | 62 | #if defined(USE_AVX) 63 | 64 | // Favor using AVX if available.
65 | /* AVX kernel: squared L2 distance over the first 16 * (qty / 16) floats, two unaligned 8-wide loads per vector per iteration; trailing elements are not read. */ static float 66 | L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 67 | float *pVect1 = (float *) pVect1v; 68 | float *pVect2 = (float *) pVect2v; 69 | size_t qty = *((size_t *) qty_ptr); 70 | float PORTABLE_ALIGN32 TmpRes[8]; 71 | size_t qty16 = qty >> 4; 72 | 73 | const float *pEnd1 = pVect1 + (qty16 << 4); 74 | 75 | __m256 diff, v1, v2; 76 | __m256 sum = _mm256_set1_ps(0); 77 | 78 | while (pVect1 < pEnd1) { 79 | v1 = _mm256_loadu_ps(pVect1); 80 | pVect1 += 8; 81 | v2 = _mm256_loadu_ps(pVect2); 82 | pVect2 += 8; 83 | diff = _mm256_sub_ps(v1, v2); 84 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); 85 | 86 | v1 = _mm256_loadu_ps(pVect1); 87 | pVect1 += 8; 88 | v2 = _mm256_loadu_ps(pVect2); 89 | pVect2 += 8; 90 | diff = _mm256_sub_ps(v1, v2); 91 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); 92 | } 93 | 94 | _mm256_store_ps(TmpRes, sum); 95 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; 96 | } 97 | 98 | #endif 99 | 100 | #if defined(USE_SSE) 101 | 102 | /* SSE kernel: squared L2 distance over the first 16 * (qty / 16) floats, four unaligned 4-wide loads per vector per iteration; trailing elements are not read. */ static float 103 | L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 104 | float *pVect1 = (float *) pVect1v; 105 | float *pVect2 = (float *) pVect2v; 106 | size_t qty = *((size_t *) qty_ptr); 107 | float PORTABLE_ALIGN32 TmpRes[8]; 108 | size_t qty16 = qty >> 4; 109 | 110 | const float *pEnd1 = pVect1 + (qty16 << 4); 111 | 112 | __m128 diff, v1, v2; 113 | __m128 sum = _mm_set1_ps(0); 114 | 115 | while (pVect1 < pEnd1) { 116 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 117 | v1 = _mm_loadu_ps(pVect1); 118 | pVect1 += 4; 119 | v2 = _mm_loadu_ps(pVect2); 120 | pVect2 += 4; 121 | diff = _mm_sub_ps(v1, v2); 122 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 123 | 124 | v1 = _mm_loadu_ps(pVect1); 125 | pVect1 += 4; 126 | v2 = _mm_loadu_ps(pVect2); 127 | pVect2 += 4; 128 | diff = _mm_sub_ps(v1, v2); 129 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 130 | 131 | v1 =
_mm_loadu_ps(pVect1); 132 | pVect1 += 4; 133 | v2 = _mm_loadu_ps(pVect2); 134 | pVect2 += 4; 135 | diff = _mm_sub_ps(v1, v2); 136 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 137 | 138 | v1 = _mm_loadu_ps(pVect1); 139 | pVect1 += 4; 140 | v2 = _mm_loadu_ps(pVect2); 141 | pVect2 += 4; 142 | diff = _mm_sub_ps(v1, v2); 143 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 144 | } 145 | 146 | _mm_store_ps(TmpRes, sum); 147 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 148 | } 149 | #endif 150 | 151 | #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) 152 | /* NOTE(review): mutable dispatch pointer for the 16-wide L2 kernel. It starts at the SSE kernel, and the L2Space constructor overwrites it with the AVX or AVX-512 kernel when AVXCapable() or AVX512Capable() report support. Because it is a static variable defined in a header, each translation unit gets its own copy. Constructing L2Space objects from several threads writes it without synchronization. */ static DISTFUNC L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE; 153 | 154 | /* L2 kernel for qty not divisible by 16: L2SqrSIMD16Ext on the leading multiple of 16 plus scalar L2Sqr on the remainder; the two partial sums are added. */ static float 155 | L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 156 | size_t qty = *((size_t *) qty_ptr); 157 | size_t qty16 = qty >> 4 << 4; 158 | float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); 159 | float *pVect1 = (float *) pVect1v + qty16; 160 | float *pVect2 = (float *) pVect2v + qty16; 161 | 162 | size_t qty_left = qty - qty16; 163 | float res_tail = L2Sqr(pVect1, pVect2, &qty_left); 164 | return (res + res_tail); 165 | } 166 | #endif 167 | 168 | 169 | #if defined(USE_SSE) 170 | /* SSE kernel: squared L2 distance over the first 4 * (qty / 4) floats, one unaligned 4-wide load per vector per iteration; trailing elements are not read. */ static float 171 | L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 172 | float PORTABLE_ALIGN32 TmpRes[8]; 173 | float *pVect1 = (float *) pVect1v; 174 | float *pVect2 = (float *) pVect2v; 175 | size_t qty = *((size_t *) qty_ptr); 176 | 177 | 178 | size_t qty4 = qty >> 2; 179 | 180 | const float *pEnd1 = pVect1 + (qty4 << 2); 181 | 182 | __m128 diff, v1, v2; 183 | __m128 sum = _mm_set1_ps(0); 184 | 185 | while (pVect1 < pEnd1) { 186 | v1 = _mm_loadu_ps(pVect1); 187 | pVect1 += 4; 188 | v2 = _mm_loadu_ps(pVect2); 189 | pVect2 += 4; 190 | diff = _mm_sub_ps(v1, v2); 191 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 192 | } 193 | _mm_store_ps(TmpRes, sum); 194 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 195 | } 196 | 197 | /* L2 kernel for qty not divisible by 4: L2SqrSIMD4Ext on the leading multiple of 4 plus scalar L2Sqr on the remaining (at most 3) elements; the two partial sums are added. */ static float 198 |
L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 199 | size_t qty = *((size_t *) qty_ptr); 200 | size_t qty4 = qty >> 2 << 2; 201 | 202 | float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); 203 | size_t qty_left = qty - qty4; 204 | 205 | float *pVect1 = (float *) pVect1v + qty4; 206 | float *pVect2 = (float *) pVect2v + qty4; 207 | float res_tail = L2Sqr(pVect1, pVect2, &qty_left); 208 | 209 | return (res + res_tail); 210 | } 211 | #endif 212 | 213 | /* Squared-L2 space for float vectors of length dim. When built with USE_AVX or USE_AVX512, the constructor first rebinds the file-scope L2SqrSIMD16Ext pointer to the widest kernel the CPU reports. It then picks the 16-wide kernel for dim % 16 == 0, the 4-wide kernel for dim % 4 == 0, and otherwise the residual variants for dim > 16 or dim > 4. Dims 1-3, and builds without SIMD, use scalar L2Sqr. get_dist_func_param() exposes &dim_, which the kernels read as their element count. */ class L2Space : public SpaceInterface { 214 | DISTFUNC fstdistfunc_; 215 | size_t data_size_; 216 | size_t dim_; 217 | 218 | public: 219 | L2Space(size_t dim) { 220 | fstdistfunc_ = L2Sqr; 221 | #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) 222 | #if defined(USE_AVX512) 223 | if (AVX512Capable()) 224 | L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512; 225 | else if (AVXCapable()) 226 | L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; 227 | #elif defined(USE_AVX) 228 | if (AVXCapable()) 229 | L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; 230 | #endif 231 | 232 | if (dim % 16 == 0) 233 | fstdistfunc_ = L2SqrSIMD16Ext; 234 | else if (dim % 4 == 0) 235 | fstdistfunc_ = L2SqrSIMD4Ext; 236 | else if (dim > 16) 237 | fstdistfunc_ = L2SqrSIMD16ExtResiduals; 238 | else if (dim > 4) 239 | fstdistfunc_ = L2SqrSIMD4ExtResiduals; 240 | #endif 241 | dim_ = dim; 242 | data_size_ = dim * sizeof(float); 243 | } 244 | 245 | size_t get_data_size() { 246 | return data_size_; 247 | } 248 | 249 | DISTFUNC get_dist_func() { 250 | return fstdistfunc_; 251 | } 252 | 253 | void *get_dist_func_param() { 254 | return &dim_; 255 | } 256 | 257 | ~L2Space() {} 258 | }; 259 | 260 | /* Squared L2 distance over unsigned char vectors, unrolled by 4. Only 4 * (qty / 4) elements are processed, so the result covers every element only when qty % 4 == 0, which is the only case in which L2SpaceI selects it. Differences are taken after promotion to int and accumulated in int. */ static int 261 | L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { 262 | size_t qty = *((size_t *) qty_ptr); 263 | int res = 0; 264 | unsigned char *a = (unsigned char *) pVect1; 265 | unsigned char *b = (unsigned char *) pVect2; 266 | 267 | qty = qty >> 2; 268 | for (size_t i = 0; i < qty; i++)
{ 269 | res += ((*a) - (*b)) * ((*a) - (*b)); 270 | a++; 271 | b++; 272 | res += ((*a) - (*b)) * ((*a) - (*b)); 273 | a++; 274 | b++; 275 | res += ((*a) - (*b)) * ((*a) - (*b)); 276 | a++; 277 | b++; 278 | res += ((*a) - (*b)) * ((*a) - (*b)); 279 | a++; 280 | b++; 281 | } 282 | return (res); 283 | } 284 | 285 | /* Scalar squared L2 distance over qty unsigned char elements, accumulated in int. */ static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) { 286 | size_t qty = *((size_t*)qty_ptr); 287 | int res = 0; 288 | unsigned char* a = (unsigned char*)pVect1; 289 | unsigned char* b = (unsigned char*)pVect2; 290 | 291 | for (size_t i = 0; i < qty; i++) { 292 | res += ((*a) - (*b)) * ((*a) - (*b)); 293 | a++; 294 | b++; 295 | } 296 | return (res); 297 | } 298 | 299 | /* Squared-L2 space for unsigned char vectors of length dim (data_size = dim bytes). It uses the 4x-unrolled kernel when dim % 4 == 0 and the scalar kernel otherwise. */ class L2SpaceI : public SpaceInterface { 300 | DISTFUNC fstdistfunc_; 301 | size_t data_size_; 302 | size_t dim_; 303 | 304 | public: 305 | L2SpaceI(size_t dim) { 306 | if (dim % 4 == 0) { 307 | fstdistfunc_ = L2SqrI4x; 308 | } else { 309 | fstdistfunc_ = L2SqrI; 310 | } 311 | dim_ = dim; 312 | data_size_ = dim * sizeof(unsigned char); 313 | } 314 | 315 | size_t get_data_size() { 316 | return data_size_; 317 | } 318 | 319 | DISTFUNC get_dist_func() { 320 | return fstdistfunc_; 321 | } 322 | 323 | void *get_dist_func_param() { 324 | return &dim_; 325 | } 326 | 327 | ~L2SpaceI() {} 328 | }; 329 | 330 | } // namespace base_hnsw -------------------------------------------------------------------------------- /src/base_hnsw/visited_list_pool.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | namespace base_hnsw { 8 | typedef unsigned short int vl_type; 9 | 10 | /* Reusable visited-marker array of numelements entries, tagged with an epoch counter curV. reset() advances curV and clears mass only when the unsigned short counter wraps to 0. Because curV starts at -1 (its maximum value), the first reset() clears the array and leaves curV at 1. Callers presumably treat mass[i] == curV as visited; that usage is not shown here. NOTE(review): mass is owned via new[], but copy construction and copy assignment are not deleted, so copying a VisitedList would lead to a double delete[]. */ class VisitedList { 11 | public: 12 | vl_type curV; 13 | vl_type *mass; 14 | unsigned int numelements; 15 | 16 | VisitedList(int numelements1) { 17 | curV = -1; 18 | numelements = numelements1; 19 | mass = new vl_type[numelements]; 20 | } 21 | 22 | void reset() { 23 | curV++; 24 | if
(curV == 0) { 25 | memset(mass, 0, sizeof(vl_type) * numelements); 26 | curV++; 27 | } 28 | }; 29 | 30 | ~VisitedList() { delete[] mass; } 31 | }; 32 | /////////////////////////////////////////////////////////// 33 | // 34 | // Class for multi-threaded pool-management of VisitedLists 35 | // 36 | ///////////////////////////////////////////////////////// 37 | 38 | /* Pool of VisitedList objects guarded by poolguard. getFreeVisitedList() pops a pooled list, or allocates a new one when the pool is empty, and calls reset() on it outside the lock. releaseVisitedList() pushes a list back. The destructor deletes only the lists currently in the pool, so lists handed out and never released are not freed. NOTE(review): in this copy the template arguments appear to be missing (std::deque pool, std::unique_lock lock) and the two #include lines above have no header names; restore them from the original source before compiling. */ class VisitedListPool { 39 | std::deque pool; 40 | std::mutex poolguard; 41 | int numelements; 42 | 43 | public: 44 | VisitedListPool(int initmaxpools, int numelements1) { 45 | numelements = numelements1; 46 | for (int i = 0; i < initmaxpools; i++) 47 | pool.push_front(new VisitedList(numelements)); 48 | } 49 | 50 | VisitedList *getFreeVisitedList() { 51 | VisitedList *rez; 52 | { 53 | std::unique_lock lock(poolguard); 54 | if (pool.size() > 0) { 55 | rez = pool.front(); 56 | pool.pop_front(); 57 | } else { 58 | rez = new VisitedList(numelements); 59 | } 60 | } 61 | rez->reset(); 62 | return rez; 63 | }; 64 | 65 | void releaseVisitedList(VisitedList *vl) { 66 | std::unique_lock lock(poolguard); 67 | pool.push_front(vl); 68 | }; 69 | 70 | ~VisitedListPool() { 71 | while (pool.size()) { 72 | VisitedList *rez = pool.front(); 73 | pool.pop_front(); 74 | delete rez; 75 | } 76 | }; 77 | }; 78 | } // namespace base_hnsw 79 | -------------------------------------------------------------------------------- /src/baselines/knn_first_hnsw.h: -------------------------------------------------------------------------------- 1 | /** 2 | * baseline #1, calculate nearest range query first, then search among the 3 | * range.
4 | * 5 | * Author: Chaoji Zuo 6 | * Date: Nov 13, 2021 7 | * Email: chaoji.zuo@rutgers.edu 8 | */ 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "incremental_hnsw/hnswlib.h" 17 | #include "index_base.h" 18 | #include "utils.h" 19 | 20 | using std::cout; 21 | using std::endl; 22 | using std::string; 23 | using std::vector; 24 | 25 | /* Inserts every node into alg_hnsw with label i (its position) using an OpenMP parallel loop. This relies on HierarchicalNSW::addPoint being safe to call concurrently; that is assumed, not shown here. */ void buildKNNFirstGraph(const vector> &nodes, 26 | hnswlib_incre::HierarchicalNSW &alg_hnsw) { 27 | #pragma omp parallel for 28 | for (size_t i = 0; i < nodes.size(); ++i) { 29 | alg_hnsw.addPoint(nodes[i].data(), i); 30 | } 31 | } 32 | 33 | /* Inserts nodes[start..end] (end inclusive) into *alg_hnsw in parallel, labelling each with its index. NOTE(review): i is size_t while end is int, so end is converted to size_t in the loop test; a negative end would make the loop run far past nodes.size(). */ void addHNSWPointsSubgraph(const vector> &nodes, 34 | hnswlib_incre::HierarchicalNSW *alg_hnsw, 35 | const int start, const int end) { 36 | #pragma omp parallel for 37 | for (size_t i = start; i <= end; ++i) { 38 | alg_hnsw->addPoint(nodes[i].data(), i); 39 | } 40 | } 41 | 42 | /* Sequential insertion (reference overload): adds the nodes in index order, labelling each with i. */ void buildKNNFirstGraphSingleThread( 43 | const vector> &nodes, 44 | hnswlib_incre::HierarchicalNSW &alg_hnsw) { 45 | for (size_t i = 0; i < nodes.size(); ++i) { 46 | alg_hnsw.addPoint(nodes[i].data(), i); 47 | } 48 | } 49 | 50 | /* Sequential insertion (pointer overload): adds the nodes in index order, labelling each with i. */ void buildKNNFirstGraphSingleThread( 51 | const vector> &nodes, 52 | hnswlib_incre::HierarchicalNSW *alg_hnsw) { 53 | for (size_t i = 0; i < nodes.size(); ++i) { 54 | alg_hnsw->addPoint(nodes[i].data(), i); 55 | } 56 | } 57 | 58 | /* Runs searchKnnCloserFirst for query with query_k and the bounds l_bound / r_bound, and returns the labels (the .second of each result pair) in the order returned. The "K to run nndescent" comment inside is a leftover; this function does not run NN-descent. */ vector KNNFirstRangeSearch( 59 | const hnswlib_incre::HierarchicalNSW &alg_hnsw, 60 | const vector &query, const int l_bound, const int r_bound, 61 | const int query_k) { 62 | // K to run nndescent 63 | 64 | vector result_in_range; 65 | auto res = 66 | alg_hnsw.searchKnnCloserFirst(query.data(), query_k, l_bound, r_bound); 67 | for (size_t j = 0; j < res.size(); j++) { 68 | int val = res[j].second; 69 | result_in_range.emplace_back(val); 70 | } 71 | return result_in_range; 72 | } 73 | 74 | /* Like KNNFirstRangeSearch, but passes true as a fifth argument to searchKnnCloserFirst; presumably this is the fixed-ef mode the name suggests (confirm in incremental_hnsw). With LOG_DEBUG_MODE defined, it prints the result and bounds and then hits assert(false), which aborts on the first call unless NDEBUG is defined. */ vector KNNFirstRangeSearchFixedEF( 75 | hnswlib_incre::HierarchicalNSW &alg_hnsw, const vector &query, 76 | const int l_bound, const int
r_bound, const int query_k) { 77 | // K to run nndescent 78 | 79 | vector result_in_range; 80 | auto res = alg_hnsw.searchKnnCloserFirst(query.data(), query_k, l_bound, 81 | r_bound, true); 82 | for (size_t j = 0; j < res.size(); j++) { 83 | int val = res[j].second; 84 | result_in_range.emplace_back(val); 85 | } 86 | 87 | #ifdef LOG_DEBUG_MODE 88 | print_set(result_in_range); 89 | cout << l_bound << "," << r_bound << endl; 90 | assert(false); 91 | #endif 92 | 93 | return result_in_range; 94 | } 95 | 96 | /* Pointer overload: first sets the index's ef to search_ef, which also affects later searches on the same index. It then calls searchKnnCloserFirst with true as the fifth argument and returns the result labels in order. */ vector KNNFirstRangeSearchFixedEF( 97 | hnswlib_incre::HierarchicalNSW *alg_hnsw, const vector &query, 98 | const int l_bound, const int r_bound, const int query_k, 99 | const int search_ef) { 100 | // K to run nndescent 101 | 102 | alg_hnsw->setEf(search_ef); 103 | 104 | vector result_in_range; 105 | auto res = alg_hnsw->searchKnnCloserFirst(query.data(), query_k, l_bound, 106 | r_bound, true); 107 | for (size_t j = 0; j < res.size(); j++) { 108 | int val = res[j].second; 109 | result_in_range.emplace_back(val); 110 | } 111 | return result_in_range; 112 | } 113 | 114 | /* BaseIndex adapter for the KNN-first baseline: every point goes into one incremental HNSW index, and query_bound is handed to searchKnnCloserFirst at query time. Owns index_info, space and hnsw_index through raw pointers. NOTE(review): the constructor leaves hnsw_index and space uninitialised, yet the destructor deletes them. Destroying a wrapper whose buildIndex() never ran therefore deletes indeterminate pointers; initialise both to nullptr. Copy operations are not deleted either, so copying the wrapper would double-delete. */ class KnnFirstWrapper : BaseIndex { 115 | public: 116 | KnnFirstWrapper(const DataWrapper *data) : BaseIndex(data) { 117 | index_info = new IndexInfo(); 118 | index_info->index_version_type = "KnnFirst-hnsw"; 119 | }; 120 | 121 | IndexInfo *index_info; 122 | 123 | hnswlib_incre::HierarchicalNSW *hnsw_index; 124 | hnswlib_incre::L2Space *space; 125 | 126 | /* Sums the sizes of the get_linklist0 neighbour lists (presumably HNSW level 0) over all data_size points. Stores the total and the per-point average in index_info and prints the average. */ void countNeighbrs() { 127 | int node_amount = 0; 128 | 129 | for (unsigned idx = 0; idx < data_wrapper->data_size; idx++) { 130 | hnswlib_incre::linklistsizeint *linklist; 131 | linklist = hnsw_index->get_linklist0(idx); 132 | size_t linklist_count = hnsw_index->getListCount(linklist); 133 | node_amount += linklist_count; 134 | } 135 | index_info->nodes_amount = node_amount; 136 | index_info->avg_forward_nns = (float)node_amount / data_wrapper->data_size; 137 | cout << "# of Avg.
Neighbors: " << index_info->avg_forward_nns << endl; 138 | } 139 | 140 | /* Builds the index sequentially. It creates an L2 space of data_dim and a HierarchicalNSW with capacity 2 * data_size; index_params->K and ef_construction are the next two constructor arguments (presumably M and ef_construction). It then inserts every node with label i, records the build time and counts neighbours. tt2 is never filled here, so CountTime(tt1, tt2) is assumed to take the end timestamp itself. NOTE(review): calling buildIndex() twice leaks the previous space and hnsw_index. */ void buildIndex(const IndexParams *index_params) override { 141 | cout << "Building baseline graph: " << index_info->index_version_type 142 | << endl; 143 | 144 | timeval tt1, tt2; 145 | gettimeofday(&tt1, NULL); 146 | space = new hnswlib_incre::L2Space(data_wrapper->data_dim); 147 | 148 | hnsw_index = new hnswlib_incre::HierarchicalNSW( 149 | space, 2 * data_wrapper->data_size, index_params->K, 150 | index_params->ef_construction); 151 | for (size_t i = 0; i < data_wrapper->data_size; ++i) { 152 | hnsw_index->addPoint(data_wrapper->nodes.at(i).data(), i); 153 | } 154 | cout << "Done" << endl; 155 | index_info->index_time = CountTime(tt1, tt2); 156 | countNeighbrs(); 157 | } 158 | 159 | /* This baseline has no separate in-range path; it forwards to rangeFilteringSearchOutBound. */ vector rangeFilteringSearchInRange( 160 | const SearchParams *search_params, SearchInfo *search_info, 161 | const vector &query, 162 | const std::pair query_bound) override { 163 | return rangeFilteringSearchOutBound(search_params, search_info, query, 164 | query_bound); 165 | } 166 | 167 | /* Attaches search_info to the index and sets ef to search_params->search_ef. It then runs searchKnnCloserFirst over query_bound with true as the fifth argument, stores the elapsed time in search_info->internal_search_time, and returns the result labels. It writes shared index state (search_info, ef), so concurrent calls on one wrapper would race. */ vector rangeFilteringSearchOutBound( 168 | const SearchParams *search_params, SearchInfo *search_info, 169 | const vector &query, 170 | const std::pair query_bound) override { 171 | timeval tt1, tt2; 172 | 173 | hnsw_index->search_info = search_info; 174 | gettimeofday(&tt1, NULL); 175 | 176 | hnsw_index->setEf(search_params->search_ef); 177 | 178 | vector result_in_range; 179 | auto res = hnsw_index->searchKnnCloserFirst( 180 | query.data(), search_params->query_K, query_bound.first, 181 | query_bound.second, true); 182 | for (size_t j = 0; j < res.size(); j++) { 183 | int val = res[j].second; 184 | result_in_range.emplace_back(val); 185 | } 186 | gettimeofday(&tt2, NULL); 187 | CountTime(tt1, tt2, search_info->internal_search_time); 188 | return result_in_range; 189 | } 190 | 191 | void saveIndex(const string &save_path) { hnsw_index->saveIndex(save_path); } 192 | 193 | /* Releases the three owned objects. If buildIndex() never ran, hnsw_index and space were never assigned (see the class NOTE). */ ~KnnFirstWrapper() { 194 | delete hnsw_index; 195 |
delete index_info; 196 | delete space; 197 | } 198 | }; 199 | 200 | /* Benchmark driver: for every ef in searchef_para_range_list, runs each query in data_wrapper, taking the query vector via query_ids and the range and groundtruth by position idx. It stores precision and approximation ratio in search_info, records the query, and logs the total time per ef value. NOTE(review): query_range is computed as second - first. If the bounds are inclusive (data_wrapper.cc builds them as l_bound + range - 1), this is one less than the number of ids in the range; confirm what consumers of query_range expect. */ void execute_knn_first_search(KnnFirstWrapper &index, 201 | BaseIndex::SearchInfo &search_info, 202 | const DataWrapper &data_wrapper, 203 | const vector &searchef_para_range_list) { 204 | timeval tt3, tt4; 205 | for (auto one_searchef : searchef_para_range_list) { 206 | gettimeofday(&tt3, NULL); 207 | for (int idx = 0; idx < data_wrapper.query_ids.size(); idx++) { 208 | int one_id = data_wrapper.query_ids.at(idx); 209 | BaseIndex::SearchParams s_params; 210 | s_params.query_K = data_wrapper.query_k; 211 | s_params.search_ef = one_searchef; 212 | s_params.control_batch_threshold = 1; 213 | s_params.query_range = data_wrapper.query_ranges.at(idx).second - 214 | data_wrapper.query_ranges.at(idx).first; 215 | auto res = index.rangeFilteringSearchOutBound( 216 | &s_params, &search_info, data_wrapper.querys.at(one_id), 217 | data_wrapper.query_ranges.at(idx)); 218 | search_info.precision = 219 | countPrecision(data_wrapper.groundtruth.at(idx), res); 220 | search_info.approximate_ratio = countApproximationRatio( 221 | data_wrapper.nodes, data_wrapper.groundtruth.at(idx), res, 222 | data_wrapper.querys.at(one_id)); 223 | 224 | // cout << data_wrapper.query_ranges.at(idx).first << " " 225 | // << data_wrapper.query_ranges.at(idx).second << endl; 226 | // print_set(res); 227 | // print_set(data_wrapper.groundtruth.at(idx)); 228 | // cout << endl; 229 | 230 | search_info.RecordOneQuery(&s_params); 231 | } 232 | 233 | logTime(tt3, tt4, "total query time"); 234 | } 235 | } 236 | 237 | /* Same driver, but query ids, ranges and groundtruth come from groundtruth_wrapper, while query vectors, query_k and nodes come from data_wrapper. */ void execute_knn_first_search_groundtruth_wrapper( 238 | KnnFirstWrapper &index, BaseIndex::SearchInfo &search_info, 239 | const DataWrapper &data_wrapper, const DataWrapper &groundtruth_wrapper, 240 | const vector &searchef_para_range_list) { 241 | timeval tt3, tt4; 242 | for (auto one_searchef : searchef_para_range_list) { 243 | gettimeofday(&tt3, NULL); 244 | for (int idx = 0; idx < groundtruth_wrapper.query_ids.size(); idx++)
{ 245 | int one_id = groundtruth_wrapper.query_ids.at(idx); 246 | BaseIndex::SearchParams s_params; 247 | s_params.query_K = data_wrapper.query_k; 248 | s_params.search_ef = one_searchef; 249 | s_params.control_batch_threshold = 1; 250 | s_params.query_range = groundtruth_wrapper.query_ranges.at(idx).second - 251 | groundtruth_wrapper.query_ranges.at(idx).first; 252 | auto res = index.rangeFilteringSearchOutBound( 253 | &s_params, &search_info, data_wrapper.querys.at(one_id), 254 | groundtruth_wrapper.query_ranges.at(idx)); 255 | search_info.precision = 256 | countPrecision(groundtruth_wrapper.groundtruth.at(idx), res); 257 | search_info.approximate_ratio = countApproximationRatio( 258 | data_wrapper.nodes, groundtruth_wrapper.groundtruth.at(idx), res, 259 | data_wrapper.querys.at(one_id)); 260 | 261 | search_info.RecordOneQuery(&s_params); 262 | } 263 | 264 | logTime(tt3, tt4, "total query time"); 265 | } 266 | } -------------------------------------------------------------------------------- /src/common/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.9) 2 | 3 | find_package(OpenMP REQUIRED) 4 | if (OPENMP_FOUND) 5 | set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 6 | set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 7 | set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") 8 | endif() 9 | 10 | file(GLOB_RECURSE native_srcs *.cc *.cpp) 11 | 12 | include_directories(${PROJECT_SOURCE_DIR}/include/common) 13 | 14 | add_library(UTIL STATIC ${native_srcs} ) 15 | 16 | target_compile_options(UTIL PRIVATE -Wall ${OpenMP_CXX_FLAGS}) 17 | target_link_libraries(UTIL ${OpenMP_CXX_FLAGS}) 18 | target_link_libraries(UTIL OpenMP::OpenMP_CXX) -------------------------------------------------------------------------------- /src/common/data_processing.cc: -------------------------------------------------------------------------------- 1 | /** 2 |
* @file data_processing.cc 3 | * @author Chaoji Zuo (chaoji.zuo@rutgers.edu) 4 | * @brief Functions for processing data, generating querys and groundtruth 5 | * @date 2023-06-19 6 | * 7 | * @copyright Copyright (c) 2023 8 | */ 9 | 10 | #include "data_processing.h" 11 | 12 | #include "data_wrapper.h" 13 | #include "omp.h" 14 | 15 | using std::pair; 16 | 17 | // void SynthesizeQuerys(const vector> &nodes, 18 | // vector> &querys, const int query_num) { 19 | // int dim = nodes.front().size(); 20 | // std::default_random_engine e; 21 | // std::uniform_int_distribution u(0, nodes.size() - 1); 22 | // querys.clear(); 23 | // querys.resize(query_num); 24 | 25 | // for (unsigned n = 0; n < query_num; n++) { 26 | // for (unsigned i = 0; i < dim; i++) { 27 | // int select_idx = u(e); 28 | // querys[n].emplace_back(nodes[select_idx][i]); 29 | // } 30 | // } 31 | // } 32 | 33 | // vector greedyNearest(const vector> &dpts, 34 | // const vector query, const int k_smallest) { 35 | // std::priority_queue> top_candidates; 36 | // float lower_bound = _INT_MAX; 37 | // for (size_t i = 0; i < dpts.size(); i++) { 38 | // float dist = EuclideanDistance(query, dpts[i]); 39 | // if (top_candidates.size() < k_smallest || dist < lower_bound) { 40 | // top_candidates.push(std::make_pair(dist, i)); 41 | // if (top_candidates.size() > k_smallest) { 42 | // top_candidates.pop(); 43 | // } 44 | 45 | // lower_bound = top_candidates.top().first; 46 | // } 47 | // } 48 | // vector res; 49 | // while (!top_candidates.empty()) { 50 | // res.emplace_back(top_candidates.top().second); 51 | // top_candidates.pop(); 52 | // } 53 | // std::reverse(res.begin(), res.end()); 54 | // return res; 55 | // } 56 | 57 | // // generate range filtering querys and calculate groundtruth 58 | // void calculateGroundtruth(DataWrapper &data_wrapper) { 59 | // std::default_random_engine e; 60 | // vector> query_ranges; 61 | // vector> groundtruth; 62 | // vector query_ids; 63 | // timeval t1, t2; 64 | // double accu_time 
= 0.0; 65 | 66 | // vector query_range_list; 67 | // float scale = 0.05; 68 | // if (data_wrapper.dataset == "local") { 69 | // scale = 0.1; 70 | // } 71 | // int init_range = data_wrapper.data_size * scale; 72 | // while (init_range < data_wrapper.data_size) { 73 | // query_range_list.emplace_back(init_range); 74 | // init_range += data_wrapper.data_size * scale; 75 | // } 76 | // #ifdef LOG_DEBUG_MODE 77 | // if (data_wrapper.dataset == "local") { 78 | // query_range_list.erase( 79 | // query_range_list.begin(), 80 | // query_range_list.begin() + 4 * query_range_list.size() / 9); 81 | // } 82 | // #endif 83 | 84 | // cout << "Generating Groundtruth..."; 85 | // // generate groundtruth 86 | 87 | // for (auto range : query_range_list) { 88 | // if (range > runner.data_size - 100) { 89 | // break; 90 | // } 91 | // uniform_int_distribution u_lbound(0, 92 | // data_wrapper.data_size - range - 80); 93 | // for (int i = 0; i < data_wrapper.querys.size(); i++) { 94 | // int l_bound = u_lbound(e); 95 | // int r_bound = l_bound + range; 96 | // int search_key_range = r_bound - l_bound; 97 | // if (data_wrapper.real_keys) 98 | // search_key_range = data_wrapper.nodes_keys.at(r_bound) - 99 | // data_wrapper.nodes_keys.at(l_bound); 100 | // for (auto query_K : query_k_range_list) { 101 | // query_ranges.emplace_back(make_pair(l_bound, r_bound)); 102 | // double greedy_time; 103 | // gettimeofday(&t1, NULL); 104 | // auto gt = greedyNearest(data_wrapper.nodes, runner.querys.at(i), 105 | // l_bound, r_bound, query_K); 106 | // gettimeofday(&t2, NULL); 107 | // CountTime(t1, t2, greedy_time); 108 | // // SaveToCSVRow("../exp_result/exp2-amarel-scalability-greedy-baseline-" 109 | // // + 110 | // // to_string(query_K) + "-" + to_string(nodes.size()) + 111 | // // "-" + dataset + ".csv", 112 | // // i, l_bound, r_bound, r_bound - l_bound, 113 | // // search_key_range, query_K, greedy_time, gt); 114 | // groundtruth.emplace_back(gt); 115 | // query_ids.emplace_back(i); 116 | // } 
117 | // } 118 | // } 119 | 120 | // cout << " Done!" << endl << "Groundtruth Time: " << accu_time << endl; 121 | 122 | // data_wrapper.groundtruth = std::move(groundtruth); 123 | // data_wrapper.query_ids = std::move(query_ids); 124 | // data_wrapper.query_ranges = std::move(query_ranges); 125 | // } 126 | 127 | // // Get Groundtruth for half bounded query 128 | // void calculateGroundtruthHalfBounded(ExpRunner &runner, bool is_save = false) { 129 | // default_random_engine e; 130 | // vector> query_ranges; 131 | // vector> groundtruth; 132 | // vector query_ids; 133 | 134 | // timeval t1, t2; 135 | 136 | // vector query_k_range_list = {10}; 137 | // vector query_range_list; 138 | // int init_range = runner.data_size * 0.05; 139 | // while (init_range < runner.data_size) { 140 | // query_range_list.emplace_back(init_range); 141 | // init_range += runner.data_size * 0.05; 142 | // } 143 | 144 | // cout << "Generating Groundtruth..."; 145 | // // generate groundtruth 146 | // for (auto range : query_range_list) { 147 | // if (range > runner.data_size - 100) { 148 | // break; 149 | // } 150 | // for (int i = 0; i < runner.querys.size(); i++) { 151 | // // int l_bound = range - 1; 152 | // // int r_bound = runner.data_size - 1; 153 | 154 | // int l_bound = 0; 155 | // int r_bound = range; 156 | 157 | // int search_key_range = r_bound - l_bound; 158 | // if (runner.real_keys) 159 | // search_key_range = 160 | // runner.nodes_keys.at(r_bound) - runner.nodes_keys.at(l_bound); 161 | // for (auto query_K : query_k_range_list) { 162 | // query_ranges.emplace_back(make_pair(l_bound, r_bound)); 163 | // double greedy_time; 164 | // gettimeofday(&t1, NULL); 165 | // auto gt = greedyNearest(runner.nodes, runner.querys.at(i), l_bound, 166 | // r_bound, query_K); 167 | // gettimeofday(&t2, NULL); 168 | // CountTime(t1, t2, greedy_time); 169 | // // SaveToCSVRow("../exp_result/exp2-amarel-scalability-greedy-baseline-" 170 | // // + 171 | // // to_string(query_K) + "-" + 
to_string(nodes.size()) + 172 | // // "-" + dataset + ".csv", 173 | // // i, l_bound, r_bound, r_bound - l_bound, 174 | // // search_key_range, query_K, greedy_time, gt); 175 | // groundtruth.emplace_back(gt); 176 | // query_ids.emplace_back(i); 177 | // } 178 | // } 179 | // } 180 | 181 | // runner.query_ids.swap(query_ids); 182 | // runner.query_ranges.swap(query_ranges); 183 | // runner.groundtruth.swap(groundtruth); 184 | // cout << " Done!" << endl; 185 | // } -------------------------------------------------------------------------------- /src/common/data_wrapper.cc: -------------------------------------------------------------------------------- 1 | #include "data_wrapper.h" 2 | 3 | #include "reader.h" 4 | #include "utils.h" 5 | 6 | /* Builds query_num synthetic queries. Coordinate i of each query is copied from coordinate i of a node drawn uniformly at random, with a fresh draw per coordinate. The engine is default-constructed, so the output is identical on every run. NOTE(review): nodes.front() and the distribution bound nodes.size() - 1 assume nodes is non-empty. */ void SynthesizeQuerys(const vector> &nodes, 7 | vector> &querys, const int query_num) { 8 | int dim = nodes.front().size(); 9 | std::default_random_engine e; 10 | std::uniform_int_distribution u(0, nodes.size() - 1); 11 | querys.clear(); 12 | querys.resize(query_num); 13 | 14 | for (unsigned n = 0; n < query_num; n++) { 15 | for (unsigned i = 0; i < dim; i++) { 16 | int select_idx = u(e); 17 | querys[n].emplace_back(nodes[select_idx][i]); 18 | } 19 | } 20 | } 21 | 22 | /* Loads base vectors and queries through ReadDataWrapper and synthesizes queries when none were loaded. It then forces real_keys to false and sets data_dim from the first node. index_permutation is declared but unused, because the sorting code is commented out. */ void DataWrapper::readData(string &dataset_path, string &query_path) { 23 | ReadDataWrapper(dataset, dataset_path, this->nodes, data_size, query_path, 24 | this->querys, query_num, this->nodes_keys); 25 | cout << "Load vecs from: " << dataset_path << endl; 26 | cout << "# of vecs: " << nodes.size() << endl; 27 | 28 | // already sort data in sorted_data 29 | // if (dataset != "wiki-image" && dataset != "yt8m") { 30 | // nodes_keys.resize(nodes.size()); 31 | // iota(nodes_keys.begin(), nodes_keys.end(), 0); 32 | // } 33 | 34 | if (querys.empty()) { 35 | cout << "Synthesizing querys..."
<< endl; 36 | SynthesizeQuerys(nodes, querys, query_num); 37 | } 38 | 39 | this->real_keys = false; 40 | vector index_permutation; // already sort data ahead 41 | 42 | // if (dataset == "wiki-image" || dataset == "yt8m") { 43 | // cout << "first search_key before sorting: " << nodes_keys.front() << 44 | // endl; cout << "sorting dataset: " << dataset << endl; index_permutation = 45 | // sort_permutation(nodes_keys); apply_permutation_in_place(nodes, 46 | // index_permutation); apply_permutation_in_place(nodes_keys, 47 | // index_permutation); cout << "Dimension: " << nodes.front().size() << 48 | // endl; cout << "first search_key: " << nodes_keys.front() << endl; 49 | // this->real_keys = true; 50 | // } 51 | this->data_dim = this->nodes.front().size(); 52 | } 53 | 54 | /* Appends one CSV row to path: idx, l_bound, r_bound, pos_range, real_search_key_range, K_neighbor and search_time, then the groundtruth ids separated by spaces. pts is currently unused. NOTE(review): the parameter list in this copy appears mangled ("const vector >,"); the body reads a parameter named gt, so restore that declaration from the original source. */ void SaveToCSVRow(const string &path, const int idx, const int l_bound, 55 | const int r_bound, const int pos_range, 56 | const int real_search_key_range, const int K_neighbor, 57 | const double &search_time, const vector >, 58 | const vector &pts) { 59 | std::ofstream file; 60 | file.open(path, std::ios_base::app); 61 | if (file) { 62 | file << idx << "," << l_bound << "," << r_bound << "," << pos_range << "," 63 | << real_search_key_range << "," << K_neighbor << "," << search_time 64 | << ","; 65 | for (auto ele : gt) { 66 | file << ele << " "; 67 | } 68 | // file << ","; 69 | // for (auto ele : pts) { 70 | // file << ele << " "; 71 | // } 72 | file << "\n"; 73 | } 74 | file.close(); 75 | } 76 | 77 | /* Builds a list of range widths: multiples of data_size * scale, plus small fractions of data_size when they exceed 100, or a fixed list when data_size == 1000000. For every width and every query it draws a random left bound and computes groundtruth with greedyNearest over [l_bound, r_bound] (r_bound = l_bound + range - 1). It appends ranges, groundtruth and query ids, and optionally writes each row with SaveToCSVRow. The engine is default-constructed, so the generated bounds are the same on every run. */ void DataWrapper::generateRangeFilteringQueriesAndGroundtruth( 78 | bool is_save_to_file, const string save_path) { 79 | std::default_random_engine e; 80 | timeval t1, t2; 81 | double accu_time = 0.0; 82 | 83 | vector query_range_list; 84 | float scale = 0.05; 85 | if (this->dataset == "local") { 86 | scale = 0.1; 87 | } 88 | int init_range = this->data_size * scale; 89 | /* NOTE(review): if data_size * scale truncates to 0 (e.g. data_size < 20 with scale 0.05), init_range never grows and this loop does not terminate. */ while (init_range <= this->data_size) { 90 | query_range_list.emplace_back(init_range); 91 | init_range += this->data_size * scale;
92 | } 93 | #ifdef LOG_DEBUG_MODE 94 | if (this->dataset == "local") { 95 | query_range_list.erase( 96 | query_range_list.begin(), 97 | query_range_list.begin() + 4 * query_range_list.size() / 9); 98 | } 99 | #endif 100 | if (0.01 * this->data_size > 100) 101 | query_range_list.insert(query_range_list.begin(), 0.01 * this->data_size); 102 | 103 | if (0.001 * this->data_size > 100) 104 | query_range_list.insert(query_range_list.begin(), 0.001 * this->data_size); 105 | 106 | if (0.0001 * this->data_size > 100) 107 | query_range_list.insert(query_range_list.begin(), 0.0001 * this->data_size); 108 | if (this->data_size == 1000000) { 109 | query_range_list = {1000, 2000, 3000, 4000, 5000, 6000, 110 | 7000, 8000, 9000, 10000, 20000, 30000, 111 | 40000, 50000, 60000, 70000, 80000, 90000, 112 | 100000, 200000, 300000, 400000, 500000, 600000, 113 | 700000, 800000, 900000, 1000000}; 114 | } 115 | 116 | cout << "Generating Groundtruth...\nRanges: "; 117 | print_set(query_range_list); 118 | // generate groundtruth 119 | 120 | for (auto range : query_range_list) { 121 | /* NOTE(review): the upper bound data_size - range - 80 is negative whenever range > data_size - 80. That includes range == data_size, which the list can contain (e.g. 1000000, or the last multiple of data_size * scale). std::uniform_int_distribution requires a <= b, so constructing and sampling it here is undefined behaviour before the range == data_size override below runs. For data_size - 80 < range < data_size the sampled l_bound is unspecified. Clamp the bound at 0, or skip sampling when range == data_size. */ std::uniform_int_distribution u_lbound(0, 122 | this->data_size - range - 80); 123 | for (int i = 0; i < this->querys.size(); i++) { 124 | int l_bound = u_lbound(e); 125 | int r_bound = l_bound + range - 1; 126 | if (range == this->data_size) { 127 | l_bound = 0; 128 | r_bound = this->data_size - 1; 129 | } 130 | int search_key_range = r_bound - l_bound + 1; 131 | if (this->real_keys) 132 | search_key_range = 133 | this->nodes_keys.at(r_bound) - this->nodes_keys.at(l_bound); 134 | query_ranges.emplace_back(std::make_pair(l_bound, r_bound)); 135 | double greedy_time; 136 | gettimeofday(&t1, NULL); 137 | auto gt = greedyNearest(this->nodes, this->querys.at(i), l_bound, r_bound, 138 | this->query_k); 139 | gettimeofday(&t2, NULL); 140 | CountTime(t1, t2, greedy_time); 141 | groundtruth.emplace_back(gt); 142 | query_ids.emplace_back(i); 143 | accu_time += greedy_time; 144 | if (is_save_to_file) { 145 |
SaveToCSVRow(save_path, i, l_bound, r_bound, range, search_key_range, 146 | this->query_k, greedy_time, gt, this->querys.at(i)); 147 | } 148 | } 149 | } 150 | 151 | cout << " Done!" << endl << "Groundtruth Time: " << accu_time << endl; 152 | if (is_save_to_file) { 153 | cout << "Save GroundTruth to path: " << save_path << endl; 154 | } 155 | } 156 | 157 | // Get Groundtruth for half bounded query 158 | void DataWrapper::generateHalfBoundedQueriesAndGroundtruth( 159 | bool is_save_to_file, const string save_path) { 160 | timeval t1, t2; 161 | 162 | vector query_range_list; 163 | int init_range = this->data_size * 0.05; 164 | while (init_range <= this->data_size) { 165 | query_range_list.emplace_back(init_range); 166 | init_range += this->data_size * 0.05; 167 | } 168 | if (0.01 * this->data_size > 100) 169 | query_range_list.insert(query_range_list.begin(), 0.01 * this->data_size); 170 | 171 | if (0.001 * this->data_size > 100) 172 | query_range_list.insert(query_range_list.begin(), 0.001 * this->data_size); 173 | 174 | if (0.0001 * this->data_size > 100) 175 | query_range_list.insert(query_range_list.begin(), 0.0001 * this->data_size); 176 | 177 | if (this->data_size == 1000000) { 178 | query_range_list = {1000, 2000, 3000, 4000, 5000, 6000, 179 | 7000, 8000, 9000, 10000, 20000, 30000, 180 | 40000, 50000, 60000, 70000, 80000, 90000, 181 | 100000, 200000, 300000, 400000, 500000, 600000, 182 | 700000, 800000, 900000, 1000000}; 183 | } 184 | 185 | if (this->data_size == 100000) { 186 | query_range_list = {1000, 2000, 3000, 4000, 5000, 6000, 7000, 187 | 8000, 9000, 10000, 20000, 30000, 40000, 50000, 188 | 60000, 70000, 80000, 90000, 100000}; 189 | } 190 | 191 | cout << "Generating Half Bounded Groundtruth..."; 192 | cout << endl << "Ranges: " << endl; 193 | print_set(query_range_list); 194 | // generate groundtruth 195 | for (auto range : query_range_list) { 196 | for (int i = 0; i < this->querys.size(); i++) { 197 | int l_bound = 0; 198 | int r_bound = range - 1; 199
| 200 | int search_key_range = r_bound - l_bound + 1; 201 | if (this->real_keys) 202 | search_key_range = 203 | this->nodes_keys.at(r_bound) - this->nodes_keys.at(l_bound); 204 | query_ranges.emplace_back(std::make_pair(l_bound, r_bound)); 205 | double greedy_time; 206 | gettimeofday(&t1, NULL); 207 | auto gt = greedyNearest(this->nodes, this->querys.at(i), l_bound, r_bound, 208 | this->query_k); 209 | gettimeofday(&t2, NULL); 210 | CountTime(t1, t2, greedy_time); 211 | groundtruth.emplace_back(gt); 212 | query_ids.emplace_back(i); 213 | if (is_save_to_file) { 214 | SaveToCSVRow(save_path, i, l_bound, r_bound, range, search_key_range, 215 | this->query_k, greedy_time, gt, this->querys.at(i)); 216 | } 217 | } 218 | } 219 | cout << " Done!" << endl; 220 | if (is_save_to_file) { 221 | cout << "Save GroundTruth to path: " << save_path << endl; 222 | } 223 | } 224 | 225 | void DataWrapper::LoadGroundtruth(const string >_path) { 226 | this->groundtruth.clear(); 227 | this->query_ranges.clear(); 228 | this->query_ids.clear(); 229 | cout << "Loading Groundtruth from" << gt_path << "..."; 230 | ReadGroundtruthQuery(this->groundtruth, this->query_ranges, this->query_ids, 231 | gt_path); 232 | cout << " Done!" 
<< endl; 233 | } 234 | 235 | void DataWrapper::generateRangeFilteringQueriesAndGroundtruthScalability( 236 | bool is_save_to_file, const string save_path) { 237 | std::default_random_engine e; 238 | timeval t1, t2; 239 | double accu_time = 0.0; 240 | 241 | vector query_range_list; 242 | float scale = 0.001; 243 | int init_range = this->data_size * scale; 244 | while (init_range < 0.01 * this->data_size) { 245 | query_range_list.emplace_back(init_range); 246 | init_range += this->data_size * scale; 247 | } 248 | scale = 0.01; 249 | init_range = this->data_size * scale; 250 | while (init_range < 0.1 * this->data_size) { 251 | query_range_list.emplace_back(init_range); 252 | init_range += this->data_size * scale; 253 | } 254 | 255 | scale = 0.1; 256 | init_range = this->data_size * scale; 257 | while (init_range < 1 * this->data_size) { 258 | query_range_list.emplace_back(init_range); 259 | init_range += this->data_size * scale; 260 | } 261 | 262 | query_range_list.emplace_back(this->data_size); 263 | 264 | cout << "Generating Groundtruth...\nRanges: "; 265 | print_set(query_range_list); 266 | cout << "sample size:" << this->nodes.size() << endl; 267 | // generate groundtruth 268 | 269 | for (auto range : query_range_list) { 270 | std::uniform_int_distribution u_lbound(0, 271 | this->data_size - range - 80); 272 | for (int i = 0; i < this->querys.size(); i++) { 273 | int l_bound = u_lbound(e); 274 | int r_bound = l_bound + range - 1; 275 | if (range == this->data_size) { 276 | l_bound = 0; 277 | r_bound = this->data_size - 1; 278 | } 279 | int search_key_range = r_bound - l_bound + 1; 280 | // if (this->real_keys) 281 | // search_key_range = 282 | // this->nodes_keys.at(r_bound) - this->nodes_keys.at(l_bound); 283 | query_ranges.emplace_back(std::make_pair(l_bound, r_bound)); 284 | double greedy_time; 285 | gettimeofday(&t1, NULL); 286 | auto gt = greedyNearest(this->nodes, this->querys.at(i), l_bound, r_bound, 287 | this->query_k); 288 | gettimeofday(&t2, NULL); 289 
| CountTime(t1, t2, greedy_time); 290 | groundtruth.emplace_back(gt); 291 | query_ids.emplace_back(i); 292 | accu_time += greedy_time; 293 | if (is_save_to_file) { 294 | SaveToCSVRow(save_path, i, l_bound, r_bound, range, search_key_range, 295 | this->query_k, greedy_time, gt, this->querys.at(i)); 296 | } 297 | } 298 | } 299 | 300 | cout << " Done!" << endl << "Groundtruth Time: " << accu_time << endl; 301 | if (is_save_to_file) { 302 | cout << "Save GroundTruth to path: " << save_path << endl; 303 | } 304 | } 305 | 306 | // Get Groundtruth for half bounded query 307 | void DataWrapper::generateHalfBoundedQueriesAndGroundtruthScalability( 308 | bool is_save_to_file, const string save_path) { 309 | timeval t1, t2; 310 | 311 | vector query_range_list; 312 | float scale = 0.001; 313 | int init_range = this->data_size * scale; 314 | while (init_range < 0.01 * this->data_size) { 315 | query_range_list.emplace_back(init_range); 316 | init_range += this->data_size * scale; 317 | } 318 | scale = 0.01; 319 | init_range = this->data_size * scale; 320 | while (init_range < 0.1 * this->data_size) { 321 | query_range_list.emplace_back(init_range); 322 | init_range += this->data_size * scale; 323 | } 324 | 325 | scale = 0.1; 326 | init_range = this->data_size * scale; 327 | while (init_range < 1 * this->data_size) { 328 | query_range_list.emplace_back(init_range); 329 | init_range += this->data_size * scale; 330 | } 331 | 332 | query_range_list.emplace_back(this->data_size); 333 | 334 | cout << "Generating Half Bounded Groundtruth..."; 335 | cout << endl << "Ranges: " << endl; 336 | print_set(query_range_list); 337 | // generate groundtruth 338 | for (auto range : query_range_list) { 339 | for (int i = 0; i < this->querys.size(); i++) { 340 | int l_bound = 0; 341 | int r_bound = range - 1; 342 | 343 | int search_key_range = r_bound - l_bound + 1; 344 | if (this->real_keys) 345 | search_key_range = 346 | this->nodes_keys.at(r_bound) - this->nodes_keys.at(l_bound); 347 | 
query_ranges.emplace_back(std::make_pair(l_bound, r_bound)); 348 | double greedy_time; 349 | gettimeofday(&t1, NULL); 350 | auto gt = greedyNearest(this->nodes, this->querys.at(i), l_bound, r_bound, 351 | this->query_k); 352 | gettimeofday(&t2, NULL); 353 | CountTime(t1, t2, greedy_time); 354 | groundtruth.emplace_back(gt); 355 | query_ids.emplace_back(i); 356 | if (is_save_to_file) { 357 | SaveToCSVRow(save_path, i, l_bound, r_bound, range, search_key_range, 358 | this->query_k, greedy_time, gt, this->querys.at(i)); 359 | } 360 | } 361 | } 362 | cout << " Done!" << endl; 363 | if (is_save_to_file) { 364 | cout << "Save GroundTruth to path: " << save_path << endl; 365 | } 366 | } 367 | 368 | // For evaluating in Benchmark, use only 7 points in the benchmark: 0.1% 0.5% 1% 369 | // 5% 10% 50% 100% 370 | 371 | void DataWrapper::generateHalfBoundedQueriesAndGroundtruthBenchmark( 372 | bool is_save_to_file, const string save_path) { 373 | timeval t1, t2; 374 | 375 | vector query_range_list; 376 | query_range_list.emplace_back(this->data_size * 0.001); 377 | query_range_list.emplace_back(this->data_size * 0.005); 378 | query_range_list.emplace_back(this->data_size * 0.01); 379 | query_range_list.emplace_back(this->data_size * 0.05); 380 | query_range_list.emplace_back(this->data_size * 0.1); 381 | query_range_list.emplace_back(this->data_size * 0.5); 382 | query_range_list.emplace_back(this->data_size); 383 | 384 | cout << "Generating Half Bounded Groundtruth..."; 385 | cout << endl << "Ranges: " << endl; 386 | print_set(query_range_list); 387 | // generate groundtruth 388 | for (auto range : query_range_list) { 389 | for (int i = 0; i < this->querys.size(); i++) { 390 | int l_bound = 0; 391 | int r_bound = range - 1; 392 | 393 | int search_key_range = r_bound - l_bound + 1; 394 | if (this->real_keys) 395 | search_key_range = 396 | this->nodes_keys.at(r_bound) - this->nodes_keys.at(l_bound); 397 | query_ranges.emplace_back(std::make_pair(l_bound, r_bound)); 398 | double 
greedy_time; 399 | gettimeofday(&t1, NULL); 400 | auto gt = greedyNearest(this->nodes, this->querys.at(i), l_bound, r_bound, 401 | this->query_k); 402 | gettimeofday(&t2, NULL); 403 | CountTime(t1, t2, greedy_time); 404 | groundtruth.emplace_back(gt); 405 | query_ids.emplace_back(i); 406 | if (is_save_to_file) { 407 | SaveToCSVRow(save_path, i, l_bound, r_bound, range, search_key_range, 408 | this->query_k, greedy_time, gt, this->querys.at(i)); 409 | } 410 | } 411 | } 412 | cout << " Done!" << endl; 413 | if (is_save_to_file) { 414 | cout << "Save GroundTruth to path: " << save_path << endl; 415 | } 416 | } 417 | 418 | void DataWrapper::generateRangeFilteringQueriesAndGroundtruthBenchmark( 419 | bool is_save_to_file, const string save_path) { 420 | timeval t1, t2; 421 | 422 | vector query_range_list; 423 | query_range_list.emplace_back(this->data_size * 0.001); 424 | query_range_list.emplace_back(this->data_size * 0.005); 425 | query_range_list.emplace_back(this->data_size * 0.01); 426 | query_range_list.emplace_back(this->data_size * 0.05); 427 | query_range_list.emplace_back(this->data_size * 0.1); 428 | query_range_list.emplace_back(this->data_size * 0.5); 429 | query_range_list.emplace_back(this->data_size); 430 | 431 | cout << "Generating Range Filtering Groundtruth..."; 432 | cout << endl << "Ranges: " << endl; 433 | print_set(query_range_list); 434 | // generate groundtruth 435 | 436 | std::default_random_engine e; 437 | 438 | for (auto range : query_range_list) { 439 | std::uniform_int_distribution u_lbound(0, 440 | this->data_size - range - 80); 441 | for (int i = 0; i < this->querys.size(); i++) { 442 | int l_bound = u_lbound(e); 443 | int r_bound = l_bound + range - 1; 444 | if (range == this->data_size) { 445 | l_bound = 0; 446 | r_bound = this->data_size - 1; 447 | } 448 | int search_key_range = r_bound - l_bound + 1; 449 | // if (this->real_keys) 450 | // search_key_range = 451 | // this->nodes_keys.at(r_bound) - this->nodes_keys.at(l_bound); 452 | 
query_ranges.emplace_back(std::make_pair(l_bound, r_bound)); 453 | double greedy_time; 454 | gettimeofday(&t1, NULL); 455 | auto gt = greedyNearest(this->nodes, this->querys.at(i), l_bound, r_bound, 456 | this->query_k); 457 | gettimeofday(&t2, NULL); 458 | CountTime(t1, t2, greedy_time); 459 | groundtruth.emplace_back(gt); 460 | query_ids.emplace_back(i); 461 | if (is_save_to_file) { 462 | SaveToCSVRow(save_path, i, l_bound, r_bound, range, search_key_range, 463 | this->query_k, greedy_time, gt, this->querys.at(i)); 464 | } 465 | } 466 | } 467 | 468 | if (is_save_to_file) { 469 | cout << "Save GroundTruth to path: " << save_path << endl; 470 | } 471 | } 472 | -------------------------------------------------------------------------------- /src/common/logger.cc: -------------------------------------------------------------------------------- 1 | #include "logger.h" 2 | 3 | // compact hnsw 4 | void SaveToCSVRow(const string &path, const int idx, const int l_bound, 5 | const int r_bound, const int range, const int K_neighbor, 6 | const int initial_graph_size, const int index_graph_size, 7 | const string &method, const int search_ef, 8 | const double &precision, const double &appr_ratio, 9 | const double &search_time, const int data_size, 10 | const int num_comparison, const double path_time) { 11 | std::ofstream file; 12 | file.open(path, std::ios_base::app); 13 | if (file) { 14 | file << idx << "," << l_bound << "," << r_bound << "," << range << "," 15 | << K_neighbor << "," << initial_graph_size << "," << index_graph_size 16 | << "," << method << "," << search_ef << "," << precision << "," 17 | << appr_ratio << "," << search_time << "," << data_size << "," 18 | << num_comparison << "," << path_time; 19 | file << "\n"; 20 | } 21 | file.close(); 22 | } 23 | 24 | // knn-first 25 | void SaveToCSVRow(const string &path, const int idx, const int l_bound, 26 | const int r_bound, const int range, const int K_neighbor, 27 | const int initial_graph_size, const int 
index_graph_size, 28 | const string &method, const int search_ef, 29 | const double &precision, const double &appr_ratio, 30 | const double &search_time, const int data_size, 31 | const size_t num_search_comparison, 32 | const double out_bound_candidates, 33 | const double in_bound_candidates) { 34 | std::ofstream file; 35 | file.open(path, std::ios_base::app); 36 | if (file) { 37 | file << idx << "," << l_bound << "," << r_bound << "," << range << "," 38 | << K_neighbor << "," << initial_graph_size << "," << index_graph_size 39 | << "," << method << "," << search_ef << "," << precision << "," 40 | << appr_ratio << "," << search_time << "," << data_size << "," 41 | << num_search_comparison << "," << out_bound_candidates << "," 42 | << in_bound_candidates; 43 | file << "\n"; 44 | } 45 | file.close(); 46 | } 47 | 48 | void SaveToIndexCSVRow(const string &path, const string &version, 49 | const string &method, const int data_size, 50 | const int initial_graph_size, const int index_graph_size, 51 | const double nn_build_time, const double sort_time, 52 | const double build_time, const double memory, 53 | const int node_amount, const int window_count, 54 | const double index_size) { 55 | std::ofstream file; 56 | file.open(path, std::ios_base::app); 57 | if (file) { 58 | file << version << "," << method << "," << data_size << "," 59 | << initial_graph_size << "," << index_graph_size << "," 60 | << nn_build_time << "," << sort_time << "," << build_time << "," 61 | << memory << "," << node_amount << "," << window_count << "," 62 | << index_size; 63 | file << "\n"; 64 | } 65 | file.close(); 66 | } 67 | 68 | // For range filtering, HNSW detail 69 | void SaveToCSVRow(const string &path, const int idx, const int l_bound, 70 | const int r_bound, const int range, const int K_neighbor, 71 | const int initial_graph_size, const int index_graph_size, 72 | const string &method, const int search_ef, 73 | const double &precision, const double &appr_ratio, 74 | const double 
&search_time, const int data_size, 75 | vector &res, vector &dists) { 76 | std::ofstream file; 77 | file.open(path, std::ios_base::app); 78 | if (file) { 79 | file << idx << "," << l_bound << "," << r_bound << "," << range << "," 80 | << K_neighbor << "," << initial_graph_size << "," << index_graph_size 81 | << "," << method << "," << search_ef << "," << precision << "," 82 | << appr_ratio << "," << search_time << "," << data_size << ","; 83 | 84 | for (auto ele : res) { 85 | file << ele << " "; 86 | } 87 | file << ","; 88 | for (auto ele : dists) { 89 | file << ele << " "; 90 | } 91 | file << "\n"; 92 | } 93 | file.close(); 94 | } 95 | 96 | // For PQ 97 | void SaveToCSVRow(const string &path, const int idx, const int l_bound, 98 | const int r_bound, const int range, const int K_neighbor, 99 | const int M_pq, const int Ks_pq, const string &method, 100 | const double &precision, const double &appr_ratio, 101 | const double &search_time, const int data_size) { 102 | std::ofstream file; 103 | file.open(path, std::ios_base::app); 104 | if (file) { 105 | file << idx << "," << l_bound << "," << r_bound << "," << range << "," 106 | << K_neighbor << "," << M_pq << "," << Ks_pq << "," << method << "," 107 | << precision << "," << appr_ratio << "," << search_time << "," 108 | << data_size; 109 | file << "\n"; 110 | } 111 | file.close(); 112 | } -------------------------------------------------------------------------------- /src/common/reader.cc: -------------------------------------------------------------------------------- 1 | #include "reader.h" 2 | 3 | #include 4 | 5 | using std::vector; 6 | FvecsItrReader::FvecsItrReader(std::string filename) { 7 | ifs.open(filename, std::ios::binary); 8 | assert(ifs.is_open()); 9 | Next(); 10 | } 11 | 12 | bool FvecsItrReader::IsEnd() { return eof_flag; } 13 | 14 | std::vector FvecsItrReader::Next() { 15 | std::vector prev_vec = vec; // return the currently stored vec 16 | int D; 17 | if (ifs.read((char *)&D, sizeof(int))) { // read 
"D" 18 | // Then, read a D-dim vec 19 | vec.resize(D); // allocate D-dim 20 | ifs.read((char *)vec.data(), sizeof(float) * D); // Read D * float. 21 | eof_flag = false; 22 | } else { 23 | vec.clear(); 24 | eof_flag = true; 25 | } 26 | return prev_vec; 27 | } 28 | 29 | BvecsItrReader::BvecsItrReader(std::string filename) { 30 | ifs.open(filename, std::ios::binary); 31 | assert(ifs.is_open()); 32 | Next(); 33 | } 34 | 35 | bool BvecsItrReader::IsEnd() { return eof_flag; } 36 | 37 | std::vector BvecsItrReader::Next() { 38 | std::vector prev_vec = vec; // return the currently stored vec 39 | int D; 40 | if (ifs.read((char *)&D, sizeof(int))) { // read "D" 41 | // Then, read a D-dim vec 42 | vec.resize(D); // allocate D-dim 43 | std::vector buff(D); 44 | 45 | assert(ifs.read((char *)buff.data(), 46 | sizeof(unsigned char) * D)); // Read D * uchar. 47 | 48 | // Convert uchar to float 49 | for (int d = 0; d < D; ++d) { 50 | vec[d] = static_cast(buff[d]); 51 | } 52 | 53 | eof_flag = false; 54 | } else { 55 | vec.clear(); 56 | eof_flag = true; 57 | } 58 | return prev_vec; 59 | } 60 | 61 | ItrReader::ItrReader(std::string filename, std::string ext) { 62 | if (ext == "fvecs") { 63 | m_reader = (I_ItrReader *)new FvecsItrReader(filename); 64 | } else if (ext == "bvecs") { 65 | m_reader = (I_ItrReader *)new BvecsItrReader(filename); 66 | } else { 67 | std::cerr << "Error: strange ext type: " << ext << "in ItrReader" 68 | << std::endl; 69 | exit(1); 70 | } 71 | } 72 | 73 | ItrReader::~ItrReader() { delete m_reader; } 74 | 75 | bool ItrReader::IsEnd() { return m_reader->IsEnd(); } 76 | 77 | std::vector ItrReader::Next() { return m_reader->Next(); } 78 | 79 | std::vector> ReadTopN(std::string filename, std::string ext, 80 | int top_n) { 81 | std::vector> vecs; 82 | if (top_n != -1) { 83 | vecs.reserve(top_n); 84 | } 85 | ItrReader reader(filename, ext); 86 | while (!reader.IsEnd()) { 87 | if (top_n != -1 && top_n <= (int)vecs.size()) { 88 | break; 89 | } 90 | 
vecs.emplace_back(reader.Next()); 91 | } 92 | return vecs; 93 | } 94 | 95 | /// @brief Reading binary data vectors. Raw data store as a (N x 100) 96 | /// binary file. 97 | /// @param file_path file path of binary data 98 | /// @param data returned 2D data vectors 99 | /// @param N Reading top N vectors 100 | /// @param num_dimensions dimension of dataset 101 | void ReadFvecsTopN(const std::string &file_path, 102 | std::vector> &data, const uint32_t N, 103 | const int num_dimensions) { 104 | std::cout << "Reading Data: " << file_path << std::endl; 105 | std::ifstream ifs; 106 | ifs.open(file_path, std::ios::binary); 107 | assert(ifs.is_open()); 108 | 109 | data.resize(N); 110 | std::vector buff(num_dimensions); 111 | int counter = 0; 112 | while ((counter < N) && 113 | (ifs.read((char *)buff.data(), num_dimensions * sizeof(double)))) { 114 | std::vector row(num_dimensions); 115 | for (int d = 0; d < num_dimensions; d++) { 116 | row[d] = static_cast(buff[d]); 117 | } 118 | data[counter++] = std::move(row); 119 | } 120 | 121 | ifs.close(); 122 | std::cout << "Finish Reading Data" << endl; 123 | } 124 | 125 | /// @brief Reading binary data vectors. Raw data store as a (N x 100) 126 | /// binary file. 
Skip some nodes, for reading querys 127 | /// @param file_path file path of binary data 128 | /// @param data returned 2D data vectors 129 | /// @param N Reading top N vectors 130 | /// @param num_dimensions dimension of dataset 131 | void ReadFvecsSkipTop(const std::string &file_path, 132 | std::vector> &data, const uint32_t N, 133 | const int num_dimensions, const int skip_num) { 134 | std::cout << "Query Start From Position: " << skip_num << std::endl; 135 | std::ifstream ifs; 136 | ifs.open(file_path, std::ios::binary); 137 | assert(ifs.is_open()); 138 | ifs.seekg(num_dimensions * sizeof(double) * skip_num); 139 | 140 | data.resize(N); 141 | std::vector buff(num_dimensions); 142 | int counter = 0; 143 | while ((counter < N) && 144 | ifs.read((char *)buff.data(), num_dimensions * sizeof(double))) { 145 | std::vector row(num_dimensions); 146 | for (int d = 0; d < num_dimensions; d++) { 147 | row[d] = static_cast(buff[d]); 148 | } 149 | data[counter++] = std::move(row); 150 | } 151 | 152 | ifs.close(); 153 | } 154 | 155 | /// @brief Reading metadata information, stored in uint32_t format 156 | /// @param file_path file path of binary data 157 | /// @param data returned 2D data vectors 158 | /// @param N Reading top N vectors 159 | /// @param num_dimensions dimension of dataset 160 | void ReadIvecsTopN(const std::string &file_path, std::vector &keys, 161 | const uint32_t N, const int num_dimensions, 162 | const int position) { 163 | std::cout << "Reading Keys: " << file_path << std::endl; 164 | std::ifstream ifs; 165 | ifs.open(file_path, std::ios::binary); 166 | assert(ifs.is_open()); 167 | 168 | keys.resize(N); 169 | std::vector buff(num_dimensions); 170 | int counter = 0; 171 | while ((counter < N) && 172 | ifs.read((char *)buff.data(), num_dimensions * sizeof(uint64_t))) { 173 | keys[counter++] = static_cast(buff[position]); 174 | } 175 | 176 | ifs.close(); 177 | std::cout << "Finish Reading Keys" << endl; 178 | } 179 | 180 | void ReadDataWrapper(vector> 
&raw_data, vector &search_keys, 181 | const string &dataset, string &dataset_path, 182 | const int item_num) { 183 | raw_data.clear(); 184 | if (dataset == "glove") { 185 | ReadMatFromTxtTwitter(dataset_path, raw_data, item_num); 186 | } else if (dataset == "ml25m") { 187 | ReadMatFromTxt(dataset_path, raw_data, item_num); 188 | } else if (dataset == "sift") { 189 | raw_data = ReadTopN(dataset_path, "bvecs", item_num); 190 | } else if (dataset == "biggraph") { 191 | ReadMatFromTsv(dataset_path, raw_data, item_num); 192 | } else if (dataset == "local") { 193 | raw_data = ReadTopN(dataset_path, "fvecs", item_num); 194 | } else if (dataset == "deep") { 195 | raw_data = ReadTopN(dataset_path, "fvecs", item_num); 196 | } else if (dataset == "deep10m") { 197 | raw_data = ReadTopN(dataset_path, "fvecs", item_num); 198 | } else if (dataset == "yt8m") { 199 | ReadMatFromTsvYT8M(dataset_path, raw_data, search_keys, item_num); 200 | } else { 201 | std::cerr << "Wrong Datset!" << endl; 202 | assert(false); 203 | } 204 | } 205 | 206 | // load data and querys 207 | void ReadDataWrapper(const string &dataset, string &dataset_path, 208 | vector> &raw_data, const int data_size, 209 | string &query_path, vector> &querys, 210 | const int query_size, vector &search_keys) { 211 | raw_data.clear(); 212 | if (dataset == "glove" || dataset == "glove25" || dataset == "glove50" || 213 | dataset == "glove100" || dataset == "glove200") { 214 | ReadMatFromTxtTwitter(dataset_path, raw_data, data_size); 215 | } else if (dataset == "ml25m") { 216 | ReadMatFromTxt(dataset_path, raw_data, data_size); 217 | } else if (dataset == "sift") { 218 | raw_data = ReadTopN(dataset_path, "bvecs", data_size); 219 | querys = ReadTopN(query_path, "bvecs", query_size); 220 | } else if (dataset == "biggraph") { 221 | ReadMatFromTsv(dataset_path, raw_data, data_size); 222 | } else if (dataset == "local") { 223 | cout << dataset_path << endl; 224 | raw_data = ReadTopN(dataset_path, "fvecs", data_size); 225 | } else 
if (dataset == "deep") { 226 | raw_data = ReadTopN(dataset_path, "fvecs", data_size); 227 | querys = ReadTopN(query_path, "fvecs", query_size); 228 | } else if (dataset == "deep10m") { 229 | raw_data = ReadTopN(dataset_path, "fvecs", data_size); 230 | } else if (dataset == "yt8m") { 231 | ReadMatFromTsvYT8M(dataset_path, raw_data, search_keys, data_size); 232 | } else if (dataset == "yt8m-video") { 233 | ReadFvecsTopN(dataset_path, raw_data, data_size, 1024); 234 | ReadFvecsTopN(query_path, querys, query_size, 1024); 235 | } else if (dataset == "yt8m-audio") { 236 | ReadFvecsTopN(dataset_path, raw_data, data_size, 128); 237 | ReadFvecsTopN(query_path, querys, query_size, 128); 238 | 239 | } else if (dataset == "wiki-image") { 240 | ReadFvecsTopN(dataset_path, raw_data, data_size, 2048); 241 | ReadFvecsTopN(query_path, querys, query_size, 2048); 242 | 243 | } 244 | 245 | else { 246 | std::cerr << "Wrong Datset!" << endl; 247 | assert(false); 248 | } 249 | } 250 | 251 | void Split(std::string &s, std::string &delim, std::vector *ret) { 252 | size_t last = 0; 253 | size_t index = s.find_first_of(delim, last); 254 | while (index != std::string::npos) { 255 | ret->push_back(s.substr(last, index - last)); 256 | last = index + 1; 257 | index = s.find_first_of(delim, last); 258 | } 259 | if (index - last > 0) { 260 | ret->push_back(s.substr(last, index - last)); 261 | } 262 | } 263 | 264 | // load txt matrix data 265 | void ReadMatFromTxt(const string &path, vector> &data, 266 | const int length_limit = -1) { 267 | ifstream infile; 268 | string bline; 269 | string delim = " "; 270 | int numCols = 0; 271 | infile.open(path, ios::in); 272 | if (getline(infile, bline, '\n')) { 273 | vector ret; 274 | Split(bline, delim, &ret); 275 | numCols = ret.size(); 276 | } 277 | infile.close(); 278 | // cout << "Reading " << path << " ..." 
<< endl; 279 | // cout << "# of columns: " << numCols << endl; 280 | 281 | int counter = 0; 282 | if (length_limit == -1) counter = -9999999; 283 | // TODO: read sparse matrix 284 | infile.open(path, ios::in); 285 | while (getline(infile, bline, '\n')) { 286 | if (counter >= length_limit) break; 287 | counter++; 288 | 289 | vector ret; 290 | Split(bline, delim, &ret); 291 | vector arow(numCols); 292 | assert(ret.size() == numCols); 293 | for (int i = 0; i < ret.size(); i++) { 294 | arow[i] = static_cast(stod(ret[i])); 295 | } 296 | data.emplace_back(arow); 297 | } 298 | infile.close(); 299 | // cout << "# of rows: " << data.size() << endl; 300 | } 301 | 302 | void ReadMatFromTxtTwitter(const string &path, vector> &data, 303 | const int length_limit = -1) { 304 | ifstream infile; 305 | string bline; 306 | string delim = " "; 307 | int numCols = 0; 308 | infile.open(path, ios::in); 309 | if (getline(infile, bline, '\n')) { 310 | vector ret; 311 | Split(bline, delim, &ret); 312 | numCols = ret.size() - 1; 313 | } 314 | infile.close(); 315 | cout << "Reading " << path << " ..." 
<< endl; 316 | 317 | cout << "# of columns: " << numCols << endl; 318 | 319 | int counter = 0; 320 | if (length_limit == -1) counter = -9999999; 321 | // TODO: read sparse matrix 322 | infile.open(path, ios::in); 323 | while (getline(infile, bline, '\n')) { 324 | if (counter >= length_limit) break; 325 | counter++; 326 | 327 | vector ret; 328 | Split(bline, delim, &ret); 329 | vector arow(numCols); 330 | assert(ret.size() == numCols + 1); 331 | for (int i = 1; i < ret.size(); i++) { 332 | arow[i - 1] = static_cast(stod(ret[i])); 333 | } 334 | data.emplace_back(arow); 335 | } 336 | infile.close(); 337 | cout << "# of rows: " << data.size() << endl; 338 | } 339 | 340 | void ReadMatFromTsv(const string &path, vector> &data, 341 | const int length_limit = -1) { 342 | ifstream infile; 343 | string bline; 344 | string delim = "\t"; 345 | int numCols = 0; 346 | infile.open(path, ios::in); 347 | getline(infile, bline, '\n'); 348 | if (getline(infile, bline, '\n')) { 349 | vector ret; 350 | Split(bline, delim, &ret); 351 | numCols = ret.size(); 352 | } 353 | infile.close(); 354 | cout << "Reading " << path << " ..." 
<< endl; 355 | cout << "# of columns: " << numCols << endl; 356 | 357 | int counter = 0; 358 | if (length_limit == -1) counter = -9999999; 359 | infile.open(path, ios::in); 360 | // skip the first line 361 | getline(infile, bline, '\n'); 362 | while (getline(infile, bline, '\n')) { 363 | if (counter >= length_limit) break; 364 | counter++; 365 | 366 | vector ret; 367 | Split(bline, delim, &ret); 368 | vector arow(numCols - 1); 369 | assert(ret.size() == numCols); 370 | for (int i = 0; i < ret.size() - 1; i++) { 371 | arow[i] = static_cast(stod(ret[i + 1])); 372 | } 373 | data.emplace_back(arow); 374 | } 375 | infile.close(); 376 | cout << "# of rows: " << data.size() << endl; 377 | } 378 | 379 | int YT8M2Int(const string id) { 380 | int res = 0; 381 | for (size_t i = 0; i < 4; i++) { 382 | res *= 100; 383 | res += (int)id[i] - 38; 384 | } 385 | return res; 386 | } 387 | 388 | void ReadMatFromTsvYT8M(const string &path, vector> &data, 389 | vector &search_keys, const int length_limit) { 390 | ifstream infile; 391 | string bline; 392 | string delim = ","; 393 | int numCols = 0; 394 | infile.open(path, ios::in); 395 | getline(infile, bline, '\n'); 396 | if (getline(infile, bline, '\n')) { 397 | vector ret; 398 | Split(bline, delim, &ret); 399 | numCols = ret.size(); 400 | } 401 | infile.close(); 402 | cout << "Reading " << path << " ..." 
<< endl; 403 | cout << "# of columns: " << numCols << endl; 404 | 405 | int counter = 0; 406 | if (length_limit == -1) counter = -9999999; 407 | infile.open(path, ios::in); 408 | string delim_embed = " "; 409 | 410 | while (getline(infile, bline, '\n')) { 411 | if (counter >= length_limit) break; 412 | counter++; 413 | 414 | vector ret; 415 | Split(bline, delim, &ret); 416 | assert(ret.size() == numCols); 417 | 418 | // str 'id' to int 'id' 419 | // int one_search_key = YT8M2Int(ret[0]); 420 | int one_search_key = (int)stod(ret[1]); 421 | 422 | // add embedding 423 | string embedding_str = ret[2]; 424 | vector embedding_vec; 425 | vector arow(1024); 426 | Split(embedding_str, delim_embed, &embedding_vec); 427 | assert(embedding_vec.size() == 1024); 428 | for (int i = 0; i < embedding_vec.size() - 1; i++) { 429 | arow[i] = static_cast(stod(embedding_vec[i + 1])); 430 | } 431 | search_keys.emplace_back(one_search_key); 432 | data.emplace_back(arow); 433 | } 434 | infile.close(); 435 | cout << "# of rows: " << data.size() << endl; 436 | } 437 | 438 | void ReadMatFromTsvYT8M(const string &path, vector> &data, 439 | const int length_limit) { 440 | ifstream infile; 441 | string bline; 442 | string delim = ","; 443 | int numCols = 0; 444 | infile.open(path, ios::in); 445 | getline(infile, bline, '\n'); 446 | if (getline(infile, bline, '\n')) { 447 | vector ret; 448 | Split(bline, delim, &ret); 449 | numCols = ret.size(); 450 | } 451 | infile.close(); 452 | cout << "Reading " << path << " ..." 
<< endl; 453 | cout << "# of columns: " << numCols << endl; 454 | 455 | int counter = 0; 456 | if (length_limit == -1) counter = -9999999; 457 | infile.open(path, ios::in); 458 | string delim_embed = " "; 459 | 460 | while (getline(infile, bline, '\n')) { 461 | if (counter >= length_limit) break; 462 | counter++; 463 | 464 | vector ret; 465 | Split(bline, delim, &ret); 466 | assert(ret.size() == numCols); 467 | 468 | // add embedding 469 | string embedding_str = ret[2]; 470 | vector embedding_vec; 471 | vector arow(1024); 472 | Split(embedding_str, delim_embed, &embedding_vec); 473 | assert(embedding_vec.size() == 1024); 474 | for (int i = 0; i < embedding_vec.size() - 1; i++) { 475 | arow[i] = static_cast(stod(embedding_vec[i + 1])); 476 | } 477 | data.emplace_back(arow); 478 | } 479 | infile.close(); 480 | cout << "# of rows: " << data.size() << endl; 481 | } 482 | 483 | void ReadGroundtruthQuery(vector> >, 484 | vector> &query_ranges, 485 | vector &query_ids, string gt_path) { 486 | ifstream infile; 487 | string bline; 488 | string delim = ","; 489 | string space_delim = " "; 490 | 491 | int numCols = 0; 492 | infile.open(gt_path, ios::in); 493 | assert(infile.is_open()); 494 | 495 | int counter = 0; 496 | while (getline(infile, bline, '\n')) { 497 | counter++; 498 | vector one_gt; 499 | std::pair one_range; 500 | int one_id; 501 | vector ret; 502 | Split(bline, delim, &ret); 503 | one_id = std::stoi(ret[0]); 504 | one_range.first = std::stoi(ret[1]); 505 | one_range.second = std::stoi(ret[2]); 506 | vector str_gt; 507 | Split(ret[7], space_delim, &str_gt); 508 | str_gt.pop_back(); 509 | for (auto ele : str_gt) { 510 | one_gt.emplace_back(std::stoi(ele)); 511 | } 512 | gt.emplace_back(one_gt); 513 | query_ranges.emplace_back(one_range); 514 | query_ids.emplace_back(one_id); 515 | } 516 | } 517 | 518 | void fvecs2csv(const string &output_path, const vector> &nodes) { 519 | std::ofstream file; 520 | file.open(output_path, std::ios_base::app); 521 | for (auto row : 
nodes) { 522 | if (file) { 523 | for (auto ele : row) { 524 | file << ele << " "; 525 | } 526 | file << "\n"; 527 | } 528 | } 529 | file.close(); 530 | } 531 | 532 | // load data and querys 533 | void ReadDataWrapper(const string &dataset, string &dataset_path, 534 | vector> &raw_data, const int data_size) { 535 | raw_data.clear(); 536 | if (dataset == "glove" || dataset == "glove25" || dataset == "glove50" || 537 | dataset == "glove100" || dataset == "glove200") { 538 | ReadMatFromTxtTwitter(dataset_path, raw_data, data_size); 539 | } else if (dataset == "ml25m") { 540 | ReadMatFromTxt(dataset_path, raw_data, data_size); 541 | } else if (dataset == "sift") { 542 | raw_data = ReadTopN(dataset_path, "bvecs", data_size); 543 | } else if (dataset == "biggraph") { 544 | ReadMatFromTsv(dataset_path, raw_data, data_size); 545 | } else if (dataset == "local") { 546 | raw_data = ReadTopN(dataset_path, "fvecs", data_size); 547 | } else if (dataset == "deep") { 548 | raw_data = ReadTopN(dataset_path, "fvecs", data_size); 549 | } else if (dataset == "deep10m") { 550 | raw_data = ReadTopN(dataset_path, "fvecs", data_size); 551 | } else if (dataset == "yt8m") { 552 | ReadMatFromTsvYT8M(dataset_path, raw_data, data_size); 553 | } else { 554 | std::cerr << "Wrong Datset!" 
<< endl; 555 | assert(false); 556 | } 557 | } -------------------------------------------------------------------------------- /src/common/utils.cc: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | // l2 norm 4 | float EuclideanDistance(const vector &lhs, const vector &rhs, 5 | const int &startDim, int lensDim) { 6 | float ans = 0.0; 7 | if (lensDim == 0) { 8 | lensDim = lhs.size(); 9 | } 10 | 11 | for (int i = startDim; i < startDim + lensDim; ++i) { 12 | ans += (lhs[i] - rhs[i]) * (lhs[i] - rhs[i]); 13 | } 14 | return ans; 15 | } 16 | 17 | // l2sqr 18 | float EuclideanDistanceSquare(const vector &lhs, 19 | const vector &rhs) { 20 | float ans = 0.0; 21 | 22 | for (int i = 0; i < lhs.size(); ++i) { 23 | ans += (lhs[i] - rhs[i]) * (lhs[i] - rhs[i]); 24 | } 25 | return ans; 26 | } 27 | 28 | void testUTIL2() { cout << "hello" << endl; } 29 | 30 | float EuclideanDistance(const vector &lhs, const vector &rhs) { 31 | return EuclideanDistance(lhs, rhs, 0, 0); 32 | } 33 | 34 | // t1:begin, t2:end 35 | void AccumulateTime(timeval &t1, timeval &t2, double &val_time) { 36 | val_time += (t2.tv_sec - t1.tv_sec + 37 | (t2.tv_usec - t1.tv_usec) * 1.0 / CLOCKS_PER_SEC); 38 | } 39 | 40 | void CountTime(timeval &t1, timeval &t2, double &val_time) { 41 | val_time = 0; 42 | val_time += (t2.tv_sec - t1.tv_sec + 43 | (t2.tv_usec - t1.tv_usec) * 1.0 / CLOCKS_PER_SEC); 44 | } 45 | 46 | double CountTime(timeval &t1, timeval &t2) { 47 | double val_time = 0.0; 48 | val_time += (t2.tv_sec - t1.tv_sec + 49 | (t2.tv_usec - t1.tv_usec) * 1.0 / CLOCKS_PER_SEC); 50 | return val_time; 51 | } 52 | 53 | void logTime(timeval &begin, timeval &end, const string &log) { 54 | gettimeofday(&end, NULL); 55 | fprintf(stdout, ("# " + log + ": %.7fs\n").c_str(), 56 | end.tv_sec - begin.tv_sec + 57 | (end.tv_usec - begin.tv_usec) * 1.0 / CLOCKS_PER_SEC); 58 | }; 59 | 60 | double countPrecision(const vector &truth, const vector &pred) { 61 | double 
num_right = 0; 62 | for (auto one : truth) { 63 | if (find(pred.begin(), pred.end(), one) != pred.end()) { 64 | num_right += 1; 65 | } 66 | } 67 | return num_right / truth.size(); 68 | } 69 | 70 | double countApproximationRatio(const vector> &raw_data, 71 | const vector &truth, 72 | const vector &pred, 73 | const vector &query) { 74 | if (pred.size() == 0) { 75 | return 0; 76 | } 77 | vector truth_dist; 78 | vector pred_dist; 79 | for (auto vec : truth) { 80 | truth_dist.emplace_back(EuclideanDistance(query, raw_data[vec])); 81 | } 82 | for (auto vec : pred) { 83 | if (vec == -1) continue; 84 | pred_dist.emplace_back(EuclideanDistance(query, raw_data[vec])); 85 | } 86 | if (pred_dist.size() == 0) { 87 | return 0; 88 | } 89 | auto max_truth = *max_element(truth_dist.begin(), truth_dist.end()); 90 | auto max_pred = *max_element(pred_dist.begin(), pred_dist.end()); 91 | if (pred.size() < truth.size()) { 92 | nth_element(truth_dist.begin(), truth_dist.begin() + pred.size() - 1, 93 | truth_dist.end()); 94 | max_truth = truth_dist[pred.size() - 1]; 95 | } 96 | if (max_truth != 0) return max_pred / max_truth; 97 | // cout << "ERROR: empty pred!" << endl; 98 | return -1; 99 | } 100 | 101 | void print_memory() { 102 | #ifdef __linux__ 103 | struct sysinfo memInfo; 104 | 105 | sysinfo(&memInfo); 106 | // long long totalVirtualMem = memInfo.totalram; 107 | // // Add other values in next statement to avoid int overflow on right hand 108 | // // side... 109 | // totalVirtualMem += memInfo.totalswap; 110 | // totalVirtualMem *= memInfo.mem_unit; 111 | 112 | // long long virtualMemUsed = memInfo.totalram - memInfo.freeram; 113 | // // Add other values in next statement to avoid int overflow on right hand 114 | // // side... 
115 | // virtualMemUsed += memInfo.totalswap - memInfo.freeswap; 116 | // virtualMemUsed *= memInfo.mem_unit; 117 | // cout << "Total Virtual Memory: " << totalVirtualMem << endl; 118 | // cout << "Used Virtual Memory: " << virtualMemUsed << endl; 119 | 120 | long long totalPhysMem = memInfo.totalram; 121 | // Multiply in next statement to avoid int overflow on right hand side... 122 | totalPhysMem *= memInfo.mem_unit; 123 | 124 | long long physMemUsed = memInfo.totalram - memInfo.freeram; 125 | // Multiply in next statement to avoid int overflow on right hand side... 126 | physMemUsed *= memInfo.mem_unit; 127 | 128 | // cout << "Total Physical Memory: " << totalPhysMem << endl; 129 | cout << "Used Physical Memory: " << physMemUsed << endl; 130 | #elif __APPLE__ 131 | vm_size_t page_size; 132 | mach_port_t mach_port; 133 | mach_msg_type_number_t count; 134 | vm_statistics64_data_t vm_stats; 135 | 136 | mach_port = mach_host_self(); 137 | count = sizeof(vm_stats) / sizeof(natural_t); 138 | if (KERN_SUCCESS == host_page_size(mach_port, &page_size) && 139 | KERN_SUCCESS == host_statistics64(mach_port, HOST_VM_INFO, 140 | (host_info64_t)&vm_stats, &count)) { 141 | long long free_memory = (int64_t)vm_stats.free_count * (int64_t)page_size; 142 | 143 | long long used_memory = 144 | ((int64_t)vm_stats.active_count + (int64_t)vm_stats.inactive_count + 145 | (int64_t)vm_stats.wire_count) * 146 | (int64_t)page_size; 147 | printf("free memory: %lld\nused memory: %lld\n", free_memory, used_memory); 148 | } 149 | #endif 150 | } 151 | 152 | void record_memory(long long &memory) { 153 | #ifdef __linux__ 154 | struct sysinfo memInfo; 155 | sysinfo(&memInfo); 156 | long long physMemUsed = memInfo.totalram - memInfo.freeram; 157 | physMemUsed *= memInfo.mem_unit; 158 | memory = physMemUsed; 159 | #elif __APPLE__ 160 | vm_size_t page_size; 161 | mach_port_t mach_port; 162 | mach_msg_type_number_t count; 163 | vm_statistics64_data_t vm_stats; 164 | 165 | mach_port = mach_host_self(); 
166 | count = sizeof(vm_stats) / sizeof(natural_t); 167 | if (KERN_SUCCESS == host_page_size(mach_port, &page_size) && 168 | KERN_SUCCESS == host_statistics64(mach_port, HOST_VM_INFO, 169 | (host_info64_t)&vm_stats, &count)) { 170 | memory = ((int64_t)vm_stats.active_count + 171 | (int64_t)vm_stats.inactive_count + (int64_t)vm_stats.wire_count) * 172 | (int64_t)page_size; 173 | } 174 | #endif 175 | } 176 | 177 | vector greedyNearest(const vector> &dpts, 178 | const vector query, const int k_smallest) { 179 | std::priority_queue> top_candidates; 180 | float lower_bound = _INT_MAX; 181 | for (size_t i = 0; i < dpts.size(); i++) { 182 | float dist = EuclideanDistanceSquare(query, dpts[i]); 183 | if (top_candidates.size() < k_smallest || dist < lower_bound) { 184 | top_candidates.push(std::make_pair(dist, i)); 185 | if (top_candidates.size() > k_smallest) { 186 | top_candidates.pop(); 187 | } 188 | 189 | lower_bound = top_candidates.top().first; 190 | } 191 | } 192 | vector res; 193 | while (!top_candidates.empty()) { 194 | res.emplace_back(top_candidates.top().second); 195 | top_candidates.pop(); 196 | } 197 | std::reverse(res.begin(), res.end()); 198 | return res; 199 | } 200 | 201 | // void evaluateKNNG(const vector> >, 202 | // const vector> &knng, const int K, double 203 | // &recall, double &precision) { 204 | // assert(gt.size() == knng.size()); 205 | 206 | // double all_right = 0; 207 | // int knng_amount = 0; 208 | 209 | // #pragma omp parallel for reduction(+ : all_right) reduction(+ : knng_amount) 210 | // for (unsigned n = 0; n < gt.size(); n++) { 211 | // double num_right = 0; 212 | // // skip first, itself 213 | // for (unsigned i = 1; i < K + 1; i++) { 214 | // int one = gt[n][i]; 215 | // if (find(knng[n].begin(), knng[n].end(), one) != knng[n].end()) { 216 | // num_right += 1; 217 | // } 218 | // } 219 | // all_right += num_right; 220 | // knng_amount += knng[n].size(); 221 | // } 222 | // recall = (double)all_right / (K * gt.size()); 223 | // 
precision = (double)all_right / (float)knng_amount; 224 | // } 225 | void greedyNearest(const int query_pos, const vector> &dpts, 226 | const int k_smallest, const int l_bound, const int r_bound) { 227 | vector dist_arr; 228 | for (size_t i = l_bound; i <= r_bound; i++) { 229 | dist_arr.emplace_back(EuclideanDistance(dpts[query_pos], dpts[i], 0, 0)); 230 | } 231 | vector sorted_idxes = sort_indexes(dist_arr); 232 | 233 | // skip the point itself 234 | if (sorted_idxes[0] == query_pos) { 235 | sorted_idxes.erase(sorted_idxes.begin()); 236 | } 237 | sorted_idxes.resize(k_smallest); 238 | // print_set(sorted_idxes); 239 | } 240 | 241 | void rangeGreedy(const vector> &nodes, const int k_smallest, 242 | const int l_bound, const int r_bound) { 243 | for (size_t i = l_bound; i <= r_bound; i++) { 244 | greedyNearest(i, nodes, k_smallest, l_bound, r_bound); 245 | } 246 | } 247 | 248 | vector greedyNearest(const vector> &dpts, 249 | const vector query, const int l_bound, 250 | const int r_bound, const int k_smallest) { 251 | std::priority_queue> top_candidates; 252 | float lower_bound = _INT_MAX; 253 | for (size_t i = l_bound; i <= r_bound; i++) { 254 | float dist = EuclideanDistance(query, dpts[i]); 255 | if (top_candidates.size() < k_smallest || dist < lower_bound) { 256 | top_candidates.push(std::make_pair(dist, i)); 257 | if (top_candidates.size() > k_smallest) { 258 | top_candidates.pop(); 259 | } 260 | 261 | lower_bound = top_candidates.top().first; 262 | } 263 | } 264 | vector res; 265 | while (!top_candidates.empty()) { 266 | res.emplace_back(top_candidates.top().second); 267 | top_candidates.pop(); 268 | } 269 | return res; 270 | } 271 | 272 | // Basic HNSW heuristic Pruning function 273 | void heuristicPrune(const vector> &nodes, 274 | vector> &top_candidates, const size_t M) { 275 | if (top_candidates.size() < M) { 276 | return; 277 | } 278 | 279 | std::priority_queue> queue_closest; 280 | std::vector> return_list; 281 | while (top_candidates.size() > 0) { 282 | 
queue_closest.emplace(-top_candidates.front().second, 283 | top_candidates.front().first); 284 | top_candidates.erase(top_candidates.begin()); 285 | } 286 | 287 | while (queue_closest.size()) { 288 | if (return_list.size() >= M) break; 289 | std::pair curent_pair = queue_closest.top(); 290 | float dist_to_query = -curent_pair.first; 291 | queue_closest.pop(); 292 | bool good = true; 293 | 294 | for (std::pair second_pair : return_list) { 295 | float curdist = EuclideanDistance(nodes.at(second_pair.second), 296 | nodes.at(curent_pair.second)); 297 | if (curdist < dist_to_query) { 298 | good = false; 299 | break; 300 | } 301 | } 302 | if (good) { 303 | return_list.push_back(curent_pair); 304 | } 305 | } 306 | 307 | for (std::pair curent_pair : return_list) { 308 | top_candidates.emplace_back( 309 | make_pair(curent_pair.second, -curent_pair.first)); 310 | } 311 | } 312 | 313 | vector str2vec(const string str) { 314 | std::vector vect; 315 | std::stringstream ss(str); 316 | for (int i; ss >> i;) { 317 | vect.push_back(i); 318 | if (ss.peek() == ',') ss.ignore(); 319 | } 320 | return vect; 321 | } -------------------------------------------------------------------------------- /src/index_base.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file index_base.h 3 | * @author Chaoji Zuo (chaoji.zuo@rutgers.edu) 4 | * @brief Base class for builiding segment graph index, containing virtual 5 | * function about indexing and searching. 
6 | * @date Revised: 2024-01-11 7 | * 8 | * @copyright Copyright (c) 2024 9 | */ 10 | #pragma once 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "base_hnsw/space_l2.h" 19 | #include "data_wrapper.h" 20 | #include "utils.h" 21 | 22 | using std::cout; 23 | using std::endl; 24 | using std::vector; 25 | 26 | // TODO: manage default parameter in the same place 27 | static unsigned const default_K = 16; 28 | static unsigned const default_ef_construction = 400; 29 | 30 | class BaseIndex { 31 | public: 32 | BaseIndex(const DataWrapper* data) { data_wrapper = data; } 33 | 34 | int num_search_comparison; 35 | int k_graph_out_bound; 36 | bool isLog = true; 37 | 38 | // Indexing parameters 39 | struct IndexParams { 40 | // original params in hnsw 41 | unsigned K; // out degree boundry 42 | unsigned ef_construction = 400; 43 | unsigned random_seed = 100; 44 | unsigned ef_large_for_pruning = 400; // TODO: Depratched parameter 45 | unsigned ef_max = 2000; 46 | unsigned ef_construction_2d_max; // Replace ef_max 47 | bool print_one_batch = false; 48 | 49 | IndexParams(unsigned K, unsigned ef_construction, 50 | unsigned ef_large_for_pruning, unsigned ef_max) 51 | : K(K), 52 | ef_construction(ef_construction), 53 | ef_large_for_pruning(ef_large_for_pruning), 54 | ef_max(ef_max){}; 55 | 56 | // which position to cut during the recursion 57 | enum Recursion_Type_t { MIN_POS, MID_POS, MAX_POS, SMALL_LEFT_POS }; 58 | Recursion_Type_t recursion_type = Recursion_Type_t::MAX_POS; 59 | IndexParams() 60 | : K(default_K), 61 | ef_construction(default_ef_construction), 62 | random_seed(2023) {} 63 | }; 64 | 65 | struct IndexInfo { 66 | string index_version_type; 67 | double index_time; 68 | int window_count; 69 | int nodes_amount; 70 | float avg_forward_nns; 71 | float avg_reverse_nns; 72 | }; 73 | 74 | struct SearchParams { 75 | unsigned query_K; 76 | unsigned search_ef; 77 | unsigned query_range; 78 | float control_batch_threshold = 1; 79 | }; 
80 | 81 | struct SearchInfo { 82 | SearchInfo(const DataWrapper* data, 83 | const BaseIndex::IndexParams* index_params, const string& meth, 84 | const string& ver) { 85 | data_wrapper = data; 86 | index = index_params; 87 | version = ver; 88 | method = meth; 89 | path_counter = 0; 90 | Path(ver + "-" + data->version); 91 | }; 92 | 93 | const DataWrapper* data_wrapper; 94 | const BaseIndex::IndexParams* index; 95 | string version; 96 | string method; 97 | 98 | int index_k; 99 | 100 | double time; 101 | double precision; 102 | double approximate_ratio; 103 | int query_id; 104 | double internal_search_time; // one query time 105 | double fetch_nns_time = 0; 106 | double cal_dist_time = 0; 107 | double other_process_time = 0; 108 | // double one_query_time; 109 | size_t total_comparison = 0; 110 | // size_t visited_num; 111 | size_t path_counter; 112 | string investigate_path; 113 | string save_path; 114 | 115 | bool is_investigate = false; 116 | 117 | void Path(const string& ver) { 118 | version = ver; 119 | save_path = "../exp/search/" + version + "-" + method + "-" + 120 | data_wrapper->dataset + "-" + 121 | std::to_string(data_wrapper->data_size) + ".csv"; 122 | 123 | std::cout << "Save result to :" << save_path << std::endl; 124 | }; 125 | 126 | void RecordOneQuery(BaseIndex::SearchParams* search) { 127 | std::ofstream file; 128 | file.open(save_path, std::ios_base::app); 129 | if (file) { 130 | file << 131 | // version << "," << method << "," << 132 | internal_search_time << "," << precision << "," << approximate_ratio 133 | << "," << search->query_range << "," << search->search_ef << "," 134 | << fetch_nns_time << "," << cal_dist_time << "," 135 | << total_comparison << "," << std::to_string(index->recursion_type) 136 | << "," << index->K << "," << index->ef_max << "," 137 | << index->ef_large_for_pruning << "," << index->ef_construction; 138 | file << "\n"; 139 | } 140 | file.close(); 141 | } 142 | }; 143 | 144 | const DataWrapper* data_wrapper; 145 | 
SearchInfo* search_info; 146 | 147 | virtual void buildIndex(const IndexParams* index_params) = 0; 148 | virtual vector rangeFilteringSearchInRange( 149 | const SearchParams* search_params, SearchInfo* search_info, 150 | const vector& query, const std::pair query_bound) = 0; 151 | virtual vector rangeFilteringSearchOutBound( 152 | const SearchParams* search_params, SearchInfo* search_info, 153 | const vector& query, const std::pair query_bound) = 0; 154 | virtual ~BaseIndex() {} 155 | }; 156 | -------------------------------------------------------------------------------- /src/range_index_base.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Index Sorted Range KNN 3 | * 4 | * Author: Chaoji Zuo 5 | * Email: chaoji.zuo@rutgers.edu 6 | * Date: Oct 1, 2021 7 | */ 8 | #pragma once 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "data_wrapper.h" 17 | #include "delta_base_hnsw/space_l2.h" 18 | #include "utils.h" 19 | 20 | using std::cout; 21 | using std::endl; 22 | using std::vector; 23 | 24 | static unsigned const default_K = 16; 25 | static unsigned const default_ef_construction = 400; 26 | 27 | class BaseIndex { 28 | public: 29 | BaseIndex(const DataWrapper* data) { data_wrapper = data; } 30 | 31 | int num_search_comparison; 32 | int k_graph_out_bound; 33 | bool isLog = true; 34 | 35 | // Indexing parameters 36 | struct IndexParams { 37 | // original params in hnsw 38 | unsigned K; // out degree boundry 39 | unsigned ef_construction = 400; 40 | unsigned random_seed = 100; 41 | unsigned ef_large_for_pruning = 400; 42 | unsigned ef_max = 2000; 43 | bool print_one_batch = false; 44 | // which position to cut during the recursion 45 | enum Recursion_Type_t { MIN_POS, MID_POS, MAX_POS, SMALL_LEFT_POS }; 46 | Recursion_Type_t recursion_type = Recursion_Type_t::MAX_POS; 47 | IndexParams() 48 | : K(default_K), 49 | ef_construction(default_ef_construction), 50 | random_seed(2023) {} 
51 | }; 52 | 53 | struct IndexInfo { 54 | string index_version_type; 55 | double index_time; 56 | int window_count; 57 | int nodes_amount; 58 | float avg_forward_nns; 59 | float avg_reverse_nns; 60 | }; 61 | 62 | struct SearchParams { 63 | // vector weights; 64 | // vector> weights; 65 | 66 | // unsigned internal_search_K; 67 | unsigned query_K; 68 | unsigned search_ef; 69 | unsigned query_range; 70 | // bool search_type; 71 | float control_batch_threshold = 1; 72 | }; 73 | 74 | struct SearchInfo { 75 | SearchInfo(const DataWrapper* data, 76 | const BaseIndex::IndexParams* index_params, const string& meth, 77 | const string& ver) { 78 | data_wrapper = data; 79 | index = index_params; 80 | version = ver; 81 | method = meth; 82 | path_counter = 0; 83 | // investigate_path = "../exp_path/" + method + "-" + version + "-" + 84 | // dataset + 85 | // "-K" + std::to_string(index_k) + "-ef" + 86 | // std::to_string(search_ef) + "-path-" + 87 | // std::to_string(path_counter) + 88 | // "-invetigate-path.csv"; 89 | Path(ver + "-" + data->version); 90 | }; 91 | 92 | const DataWrapper* data_wrapper; 93 | const BaseIndex::IndexParams* index; 94 | string version; 95 | string method; 96 | 97 | int index_k; 98 | 99 | double time; 100 | double precision; 101 | double approximate_ratio; 102 | int query_id; 103 | double internal_search_time; // one query time 104 | double fetch_nns_time = 0; 105 | double cal_dist_time = 0; 106 | double other_process_time = 0; 107 | // double one_query_time; 108 | size_t total_comparison = 0; 109 | // size_t visited_num; 110 | size_t path_counter; 111 | string investigate_path; 112 | string save_path; 113 | 114 | bool is_investigate = false; 115 | 116 | // void NewPath(const int k, const int ef) { 117 | // index_k = k; 118 | // search_ef = ef; 119 | // investigate_path = "../exp_path/" + method + "-" + version + "-" + 120 | // dataset + 121 | // "-K" + std::to_string(index_k) + "-ef" + 122 | // std::to_string(search_ef) + "-path-" + 123 | // 
std::to_string(path_counter) + 124 | // "-invetigate-path.csv"; 125 | // path_counter++; 126 | // } 127 | 128 | void Path(const string& ver) { 129 | version = ver; 130 | save_path = "../exp/search/" + version + "-" + method + "-" + 131 | data_wrapper->dataset + "-" + 132 | std::to_string(data_wrapper->data_size) + ".csv"; 133 | 134 | std::cout << "Save result to :" << save_path << std::endl; 135 | }; 136 | 137 | void RecordOneQuery(BaseIndex::SearchParams* search) { 138 | std::ofstream file; 139 | file.open(save_path, std::ios_base::app); 140 | if (file) { 141 | file << 142 | // version << "," << method << "," << 143 | internal_search_time << "," << precision << "," << approximate_ratio 144 | << "," << search->query_range << "," << search->search_ef << "," 145 | << fetch_nns_time << "," << cal_dist_time << "," 146 | << total_comparison << "," << std::to_string(index->recursion_type) 147 | << "," << index->K << "," << index->ef_max << "," 148 | << index->ef_large_for_pruning << "," << index->ef_construction; 149 | file << "\n"; 150 | } 151 | file.close(); 152 | } 153 | 154 | // void SavePathInvestigate(const float v1, const float v2, const float v3, 155 | // const float v4) { 156 | // std::ofstream file; 157 | // file.open(investigate_path, std::ios_base::app); 158 | // if (file) { 159 | // file << v1 << "," << v2 << "," << v3 << "," << v4; 160 | // file << "\n"; 161 | // } 162 | // file.close(); 163 | // }; 164 | }; 165 | 166 | const DataWrapper* data_wrapper; 167 | SearchInfo* search_info; 168 | 169 | virtual void buildIndex(const IndexParams* index_params) = 0; 170 | virtual vector rangeFilteringSearchInRange( 171 | const SearchParams* search_params, SearchInfo* search_info, 172 | const vector& query, const std::pair query_bound) = 0; 173 | virtual vector rangeFilteringSearchOutBound( 174 | const SearchParams* search_params, SearchInfo* search_info, 175 | const vector& query, const std::pair query_bound) = 0; 176 | virtual ~BaseIndex() {} 177 | }; 178 | 
-------------------------------------------------------------------------------- /src/segment_graph_1d.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file segment_graph_1d.h 3 | * @author Chaoji Zuo (chaoji.zuo@rutgers.edu) 4 | * @brief Index for half-bounded range filtering search. 5 | * Lossless compression on N hnsw on search space 6 | * @date 2023-06-29; Revised 2023-12-29 7 | * 8 | * @copyright Copyright (c) 2023 9 | */ 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include "base_hnsw/hnswalg.h" 20 | #include "base_hnsw/hnswlib.h" 21 | #include "data_wrapper.h" 22 | #include "index_base.h" 23 | #include "utils.h" 24 | 25 | using namespace base_hnsw; 26 | // #define INT_MAX __INT_MAX__ 27 | 28 | namespace SeRF { 29 | 30 | /** 31 | * @brief segment neighbor structure to store segment graph edge information 32 | * if id==end_id, means haven't got pruned 33 | * id: neighbor id 34 | * dist: neighbor dist 35 | * end_id: when got pruned 36 | */ 37 | template 38 | struct SegmentNeighbor1D { 39 | SegmentNeighbor1D(int id) : id(id) {}; 40 | SegmentNeighbor1D(int id, dist_t dist, int end_id) 41 | : id(id), dist(dist), end_id(end_id) {}; 42 | int id; 43 | dist_t dist; 44 | int end_id; 45 | }; 46 | 47 | // Inherit from basic HNSW, modify the 'heuristic pruning' procedure to record 48 | // the lifecycle for SegmentGraph 49 | template 50 | class SegmentGraph1DHNSW : public HierarchicalNSW { 51 | public: 52 | SegmentGraph1DHNSW(const BaseIndex::IndexParams &index_params, 53 | SpaceInterface *s, size_t max_elements, 54 | string dataset_name_, size_t M = 16, 55 | size_t ef_construction = 200, size_t random_seed = 100) 56 | : HierarchicalNSW(s, max_elements, M, index_params.ef_construction, 57 | random_seed) { 58 | params = &index_params; 59 | // in ons-side segment graph, ef_max_ equal to ef_construction 60 | ef_max_ = index_params.ef_construction; 61 | 
ef_basic_construction_ = index_params.ef_construction; 62 | ef_construction = index_params.ef_construction; 63 | dataset_name = dataset_name_; 64 | } 65 | 66 | const BaseIndex::IndexParams *params; 67 | string dataset_name; 68 | 69 | // index storing structure 70 | vector>> *range_nns; 71 | 72 | void getNeighborsByHeuristic2RecordPruned( 73 | std::priority_queue, 74 | std::vector>, 75 | CompareByFirst> &top_candidates, 76 | const size_t M, vector> *back_nns, 77 | const int end_pos_id) { 78 | if (top_candidates.size() < M) { 79 | return; 80 | } 81 | 82 | std::priority_queue> queue_closest; 83 | std::vector> return_list; 84 | while (top_candidates.size() > 0) { 85 | queue_closest.emplace(-top_candidates.top().first, 86 | top_candidates.top().second); 87 | top_candidates.pop(); 88 | } 89 | 90 | while (queue_closest.size()) { 91 | if (return_list.size() >= M) break; 92 | std::pair curent_pair = queue_closest.top(); 93 | dist_t dist_to_query = -curent_pair.first; 94 | queue_closest.pop(); 95 | bool good = true; 96 | 97 | for (std::pair second_pair : return_list) { 98 | dist_t curdist = fstdistfunc_(getDataByInternalId(second_pair.second), 99 | getDataByInternalId(curent_pair.second), 100 | dist_func_param_); 101 | if (curdist < dist_to_query) { 102 | good = false; 103 | break; 104 | } 105 | } 106 | if (good) { 107 | return_list.push_back(curent_pair); 108 | } else { 109 | // record pruned nns, store in range_nns 110 | int external_nn = this->getExternalLabel(curent_pair.second); 111 | if (external_nn != end_pos_id) { 112 | SegmentNeighbor1D pruned_nn(external_nn, dist_to_query, 113 | end_pos_id); 114 | back_nns->emplace_back(pruned_nn); 115 | } 116 | } 117 | } 118 | 119 | // add unvisited nns 120 | while (queue_closest.size()) { 121 | std::pair curent_pair = queue_closest.top(); 122 | int external_nn = this->getExternalLabel(curent_pair.second); 123 | queue_closest.pop(); 124 | 125 | if (external_nn != end_pos_id) { 126 | SegmentNeighbor1D pruned_nn(external_nn, 
-curent_pair.first, 127 | end_pos_id); 128 | back_nns->emplace_back(pruned_nn); 129 | } 130 | } 131 | 132 | for (std::pair curent_pair : return_list) { 133 | top_candidates.emplace(-curent_pair.first, curent_pair.second); 134 | } 135 | } 136 | 137 | // since the order is important, SeRF use the external_id rather than the 138 | // inernal_id, but right now SeRF only supports building in one thread, so 139 | // acutally current external_id is equal to internal_id(cur_c). 140 | virtual tableint mutuallyConnectNewElementLevel0( 141 | const void *data_point, tableint cur_c, 142 | std::priority_queue, 143 | std::vector>, 144 | CompareByFirst> &top_candidates, 145 | int level, bool isUpdate) { 146 | size_t Mcurmax = this->maxM0_; 147 | getNeighborsByHeuristic2(top_candidates, this->M_); 148 | if (top_candidates.size() > this->M_) 149 | throw std::runtime_error( 150 | "Should be not be more than M_ candidates returned by the " 151 | "heuristic"); 152 | 153 | // forward neighbors in top candidates 154 | int external_id = this->getExternalLabel(cur_c); 155 | 156 | std::vector selectedNeighbors; 157 | selectedNeighbors.reserve(this->M_); 158 | while (top_candidates.size() > 0) { 159 | selectedNeighbors.push_back(top_candidates.top().second); 160 | 161 | top_candidates.pop(); 162 | } 163 | 164 | tableint next_closest_entry_point = selectedNeighbors.back(); 165 | 166 | { 167 | linklistsizeint *ll_cur; 168 | if (level == 0) 169 | ll_cur = this->get_linklist0(cur_c); 170 | else 171 | ll_cur = this->get_linklist(cur_c, level); 172 | 173 | if (*ll_cur && !isUpdate) { 174 | throw std::runtime_error( 175 | "The newly inserted element should have blank link list"); 176 | } 177 | this->setListCount(ll_cur, selectedNeighbors.size()); 178 | tableint *data = (tableint *)(ll_cur + 1); 179 | for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { 180 | if (data[idx] && !isUpdate) 181 | throw std::runtime_error("Possible memory corruption"); 182 | if (level > 
this->element_levels_[selectedNeighbors[idx]]) 183 | throw std::runtime_error( 184 | "Trying to make a link on a non-existent level"); 185 | 186 | data[idx] = selectedNeighbors[idx]; 187 | } 188 | } 189 | 190 | for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { 191 | // only keep half m, for removing duplicates for wiki-image and yt8m-audio 192 | if ((dataset_name == "wiki-image" || dataset_name == "yt8m-audio") && 193 | (idx > maxM0_ / 2)) { 194 | break; 195 | } 196 | std::unique_lock lock( 197 | this->link_list_locks_[selectedNeighbors[idx]]); 198 | 199 | linklistsizeint *ll_other; 200 | if (level == 0) 201 | ll_other = this->get_linklist0(selectedNeighbors[idx]); 202 | else 203 | ll_other = this->get_linklist(selectedNeighbors[idx], level); 204 | 205 | size_t sz_link_list_other = this->getListCount(ll_other); 206 | 207 | if (sz_link_list_other > Mcurmax) 208 | throw std::runtime_error("Bad value of sz_link_list_other"); 209 | if (selectedNeighbors[idx] == cur_c) 210 | throw std::runtime_error("Trying to connect an element to itself"); 211 | if (level > this->element_levels_[selectedNeighbors[idx]]) 212 | throw std::runtime_error( 213 | "Trying to make a link on a non-existent level"); 214 | 215 | tableint *data = (tableint *)(ll_other + 1); 216 | 217 | bool is_cur_c_present = false; 218 | if (isUpdate) { 219 | for (size_t j = 0; j < sz_link_list_other; j++) { 220 | if (data[j] == cur_c) { 221 | is_cur_c_present = true; 222 | break; 223 | } 224 | } 225 | } 226 | 227 | // If cur_c is already present in the neighboring connections of 228 | // `selectedNeighbors[idx]` then no need to modify any connections or 229 | // run the heuristics. 
230 | if (!is_cur_c_present) { 231 | if (sz_link_list_other < Mcurmax) { 232 | data[sz_link_list_other] = cur_c; 233 | this->setListCount(ll_other, sz_link_list_other + 1); 234 | 235 | } else { 236 | // finding the "weakest" element to replace it with the new one 237 | dist_t d_max = 238 | fstdistfunc_(this->getDataByInternalId(cur_c), 239 | this->getDataByInternalId(selectedNeighbors[idx]), 240 | this->dist_func_param_); 241 | // Heuristic: 242 | std::priority_queue, 243 | std::vector>, 244 | CompareByFirst> 245 | candidates; 246 | candidates.emplace(d_max, cur_c); 247 | 248 | for (size_t j = 0; j < sz_link_list_other; j++) { 249 | candidates.emplace( 250 | fstdistfunc_(this->getDataByInternalId(data[j]), 251 | this->getDataByInternalId(selectedNeighbors[idx]), 252 | this->dist_func_param_), 253 | data[j]); 254 | } 255 | 256 | // TODO: add mutex to support parallel 257 | // auto back_nns = &range_nns->at(selectedNeighbors[idx]); 258 | auto back_nns = 259 | &range_nns->at(this->getExternalLabel(selectedNeighbors[idx])); 260 | getNeighborsByHeuristic2RecordPruned(candidates, Mcurmax, back_nns, 261 | external_id); 262 | int indx = 0; 263 | while (candidates.size() > 0) { 264 | data[indx] = candidates.top().second; 265 | candidates.pop(); 266 | indx++; 267 | } 268 | 269 | this->setListCount(ll_other, indx); 270 | // Nearest K: 271 | /*int indx = -1; 272 | for (int j = 0; j < sz_link_list_other; j++) { 273 | dist_t d = fstdistfunc_(getDataByInternalId(data[j]), 274 | getDataByInternalId(rez[idx]), dist_func_param_); if (d > d_max) { 275 | indx = j; 276 | d_max = d; 277 | } 278 | } 279 | if (indx >= 0) { 280 | data[indx] = cur_c; 281 | } */ 282 | } 283 | } 284 | } 285 | 286 | return next_closest_entry_point; 287 | } 288 | }; 289 | 290 | template 291 | class IndexSegmentGraph1D : public BaseIndex { 292 | public: 293 | vector>> indexed_arr; 294 | 295 | IndexSegmentGraph1D(base_hnsw::SpaceInterface *s, 296 | const DataWrapper *data) 297 | : BaseIndex(data) { 298 | 
fstdistfunc_ = s->get_dist_func(); 299 | dist_func_param_ = s->get_dist_func_param(); 300 | index_info = new IndexInfo(); 301 | index_info->index_version_type = "IndexSegmentGraph1D"; 302 | } 303 | 304 | base_hnsw::DISTFUNC fstdistfunc_; 305 | void *dist_func_param_; 306 | 307 | VisitedListPool *visited_list_pool_; 308 | IndexInfo *index_info; 309 | 310 | void printOnebatch(int pos = -1) { 311 | if (pos == -1) { 312 | pos = data_wrapper->data_size / 2; 313 | } 314 | 315 | cout << "nns at position: " << pos << endl; 316 | for (auto nns : indexed_arr[pos]) { 317 | cout << nns.id << "->" << nns.end_id << ")\n" << endl; 318 | } 319 | cout << endl; 320 | } 321 | 322 | void countNeighbors() { 323 | if (!indexed_arr.empty()) 324 | for (unsigned j = 0; j < indexed_arr.size(); j++) { 325 | int temp_size = 0; 326 | temp_size += indexed_arr[j].size(); 327 | index_info->nodes_amount += temp_size; 328 | } 329 | index_info->avg_forward_nns = 330 | index_info->nodes_amount / (float)data_wrapper->data_size; 331 | if (isLog) { 332 | cout << "Avg. 
nn #: " 333 | << index_info->nodes_amount / (float)data_wrapper->data_size << endl; 334 | } 335 | 336 | index_info->avg_reverse_nns = 0; 337 | } 338 | 339 | void buildIndex(const IndexParams *index_params) override { 340 | cout << "Building Index using " << index_info->index_version_type << endl; 341 | timeval tt1, tt2; 342 | visited_list_pool_ = 343 | new base_hnsw::VisitedListPool(1, data_wrapper->data_size); 344 | 345 | // build HNSW 346 | L2Space space(data_wrapper->data_dim); 347 | SegmentGraph1DHNSW hnsw( 348 | *index_params, &space, 2 * data_wrapper->data_size, 349 | data_wrapper->dataset, index_params->K, index_params->ef_construction, 350 | index_params->random_seed); 351 | 352 | indexed_arr.clear(); 353 | indexed_arr.resize(data_wrapper->data_size); 354 | 355 | hnsw.range_nns = &indexed_arr; 356 | 357 | gettimeofday(&tt1, NULL); 358 | 359 | // multi-thread also work, but not guaranteed as the paper 360 | #ifndef NO_PARALLEL_BUILD 361 | #pragma omp parallel for schedule(monotonic : dynamic) 362 | for (size_t i = 0; i < data_wrapper->data_size; ++i) { 363 | hnsw.addPoint(data_wrapper->nodes.at(i).data(), i); 364 | } 365 | #else 366 | for (size_t i = 0; i < data_wrapper->data_size; ++i) { 367 | hnsw.addPoint(data_wrapper->nodes.at(i).data(), i); 368 | } 369 | #endif 370 | 371 | for (size_t i = 0; i < data_wrapper->data_size; ++i) { 372 | // insert not pruned hnsw graph back 373 | hnsw.get_linklist0(i); 374 | linklistsizeint *ll_cur; 375 | ll_cur = hnsw.get_linklist0(i); 376 | size_t link_list_count = hnsw.getListCount(ll_cur); 377 | tableint *data = (tableint *)(ll_cur + 1); 378 | 379 | for (size_t j = 0; j < link_list_count; j++) { 380 | int node_id = hnsw.getExternalLabel(data[j]); 381 | SegmentNeighbor1D nn(node_id, 0, node_id); 382 | indexed_arr.at(i).emplace_back(nn); 383 | } 384 | } 385 | logTime(tt1, tt2, "Construct Time"); 386 | gettimeofday(&tt2, NULL); 387 | index_info->index_time = CountTime(tt1, tt2); 388 | // count neighbors number 389 | 
countNeighbors(); 390 | 391 | if (index_params->print_one_batch) { 392 | printOnebatch(); 393 | } 394 | } 395 | 396 | // range filtering search, only calculate distance on on-range nodes. 397 | vector rangeFilteringSearchInRange( 398 | const SearchParams *search_params, SearchInfo *search_info, 399 | const vector &query, 400 | const std::pair query_bound) override { 401 | // timeval tt1, tt2; 402 | timeval tt3, tt4; 403 | 404 | VisitedList *vl = visited_list_pool_->getFreeVisitedList(); 405 | vl_type *visited_array = vl->mass; 406 | vl_type visited_array_tag = vl->curV; 407 | float lower_bound = std::numeric_limits::max(); 408 | std::priority_queue> top_candidates; 409 | std::priority_queue> candidate_set; 410 | 411 | search_info->total_comparison = 0; 412 | search_info->internal_search_time = 0; 413 | search_info->cal_dist_time = 0; 414 | search_info->fetch_nns_time = 0; 415 | num_search_comparison = 0; 416 | gettimeofday(&tt3, NULL); 417 | 418 | // three enters, SeRF disgard the hierarchical structure of HNSW 419 | vector enter_list; 420 | { 421 | int lbound = query_bound.first; 422 | int interval = (query_bound.second - lbound) / 3; 423 | for (size_t i = 0; i < 3; i++) { 424 | int point = lbound + interval * i; 425 | float dist = fstdistfunc_( 426 | query.data(), data_wrapper->nodes[point].data(), dist_func_param_); 427 | candidate_set.push(make_pair(-dist, point)); 428 | enter_list.emplace_back(point); 429 | visited_array[point] = visited_array_tag; 430 | } 431 | } 432 | 433 | // size_t hop_counter = 0; 434 | 435 | while (!candidate_set.empty()) { 436 | std::pair current_node_pair = candidate_set.top(); 437 | int current_node_id = current_node_pair.second; 438 | 439 | if (-current_node_pair.first > lower_bound) { 440 | break; 441 | } 442 | 443 | #ifdef LOG_DEBUG_MODE 444 | cout << "current node: " << current_node_pair.second << " -- " 445 | << -current_node_pair.first << endl; 446 | #endif 447 | 448 | // hop_counter++; 449 | candidate_set.pop(); 450 | auto 
neighbor_it = indexed_arr.at(current_node_id).begin(); 451 | 452 | // gettimeofday(&tt1, NULL); 453 | 454 | while (neighbor_it != indexed_arr[current_node_id].end()) { 455 | if ((neighbor_it->id < query_bound.second) && 456 | (neighbor_it->end_id == neighbor_it->id || 457 | neighbor_it->end_id >= query_bound.second)) { 458 | int candidate_id = neighbor_it->id; 459 | 460 | if (!(visited_array[candidate_id] == visited_array_tag)) { 461 | visited_array[candidate_id] = visited_array_tag; 462 | float dist = fstdistfunc_(query.data(), 463 | data_wrapper->nodes[candidate_id].data(), 464 | dist_func_param_); 465 | 466 | num_search_comparison++; 467 | if (top_candidates.size() < search_params->search_ef || 468 | lower_bound > dist) { 469 | candidate_set.push(make_pair(-dist, candidate_id)); 470 | top_candidates.push(make_pair(dist, candidate_id)); 471 | if (top_candidates.size() > search_params->search_ef) { 472 | top_candidates.pop(); 473 | } 474 | if (!top_candidates.empty()) { 475 | lower_bound = top_candidates.top().first; 476 | } 477 | } 478 | } 479 | } 480 | neighbor_it++; 481 | } 482 | // gettimeofday(&tt2, NULL); 483 | // AccumulateTime(tt1, tt2, search_info->cal_dist_time); 484 | } 485 | 486 | vector res; 487 | while (top_candidates.size() > search_params->query_K) { 488 | top_candidates.pop(); 489 | } 490 | 491 | while (!top_candidates.empty()) { 492 | res.emplace_back(top_candidates.top().second); 493 | top_candidates.pop(); 494 | } 495 | search_info->total_comparison += num_search_comparison; 496 | 497 | #ifdef LOG_DEBUG_MODE 498 | print_set(res); 499 | cout << l_bound << "," << r_bound << endl; 500 | assert(false); 501 | #endif 502 | visited_list_pool_->releaseVisitedList(vl); 503 | 504 | gettimeofday(&tt4, NULL); 505 | CountTime(tt3, tt4, search_info->internal_search_time); 506 | 507 | return res; 508 | } 509 | 510 | // also calculate outbount dists, similar to knn-first 511 | // This is bad for half bounded search. 
512 | vector rangeFilteringSearchOutBound( 513 | const SearchParams *search_params, SearchInfo *search_info, 514 | const vector &query, 515 | const std::pair query_bound) override { 516 | // timeval tt1, tt2; 517 | timeval tt3, tt4; 518 | 519 | VisitedList *vl = visited_list_pool_->getFreeVisitedList(); 520 | vl_type *visited_array = vl->mass; 521 | vl_type visited_array_tag = vl->curV; 522 | float lower_bound = std::numeric_limits::max(); 523 | std::priority_queue> top_candidates; 524 | std::priority_queue> candidate_set; 525 | 526 | search_info->total_comparison = 0; 527 | search_info->internal_search_time = 0; 528 | search_info->cal_dist_time = 0; 529 | search_info->fetch_nns_time = 0; 530 | num_search_comparison = 0; 531 | // finding enters 532 | vector enter_list; 533 | { 534 | int lbound = query_bound.first; 535 | int interval = (query_bound.second - lbound) / 3; 536 | for (size_t i = 0; i < 3; i++) { 537 | int point = lbound + interval * i; 538 | float dist = fstdistfunc_( 539 | query.data(), data_wrapper->nodes[point].data(), dist_func_param_); 540 | candidate_set.push(make_pair(-dist, point)); 541 | enter_list.emplace_back(point); 542 | visited_array[point] = visited_array_tag; 543 | } 544 | } 545 | gettimeofday(&tt3, NULL); 546 | 547 | // size_t hop_counter = 0; 548 | 549 | while (!candidate_set.empty()) { 550 | std::pair current_node_pair = candidate_set.top(); 551 | int current_node_id = current_node_pair.second; 552 | if (-current_node_pair.first > lower_bound) { 553 | break; 554 | } 555 | 556 | // hop_counter++; 557 | 558 | candidate_set.pop(); 559 | 560 | auto neighbor_it = indexed_arr.at(current_node_id).begin(); 561 | // gettimeofday(&tt1, NULL); 562 | 563 | while (neighbor_it != indexed_arr[current_node_id].end()) { 564 | if ((neighbor_it->id < query_bound.second)) { 565 | int candidate_id = neighbor_it->id; 566 | 567 | if (!(visited_array[candidate_id] == visited_array_tag)) { 568 | visited_array[candidate_id] = visited_array_tag; 569 | float 
dist = fstdistfunc_(query.data(), 570 | data_wrapper->nodes[candidate_id].data(), 571 | dist_func_param_); 572 | 573 | num_search_comparison++; 574 | if (top_candidates.size() < search_params->search_ef || 575 | lower_bound > dist) { 576 | candidate_set.emplace(-dist, candidate_id); 577 | // add to top_candidates only in range 578 | if (candidate_id <= query_bound.second && 579 | candidate_id >= query_bound.first) { 580 | top_candidates.emplace(dist, candidate_id); 581 | if (top_candidates.size() > search_params->search_ef) { 582 | top_candidates.pop(); 583 | } 584 | if (!top_candidates.empty()) { 585 | lower_bound = top_candidates.top().first; 586 | } 587 | } 588 | } 589 | } 590 | } 591 | neighbor_it++; 592 | } 593 | 594 | // gettimeofday(&tt2, NULL); 595 | // AccumulateTime(tt1, tt2, search_info->cal_dist_time); 596 | } 597 | 598 | vector res; 599 | while (top_candidates.size() > search_params->query_K) { 600 | top_candidates.pop(); 601 | } 602 | 603 | while (!top_candidates.empty()) { 604 | res.emplace_back(top_candidates.top().second); 605 | top_candidates.pop(); 606 | } 607 | search_info->total_comparison += num_search_comparison; 608 | 609 | #ifdef LOG_DEBUG_MODE 610 | print_set(res); 611 | cout << l_bound << "," << r_bound << endl; 612 | assert(false); 613 | #endif 614 | visited_list_pool_->releaseVisitedList(vl); 615 | 616 | gettimeofday(&tt4, NULL); 617 | CountTime(tt3, tt4, search_info->internal_search_time); 618 | 619 | return res; 620 | } 621 | 622 | void save(const string &save_path) { 623 | std::ofstream output(save_path, std::ios::binary); 624 | unsigned counter = 0; 625 | for (auto &segment : indexed_arr) { 626 | base_hnsw::writeBinaryPOD(output, (int)segment.size()); 627 | counter++; 628 | for (auto &nn : segment) { 629 | base_hnsw::writeBinaryPOD(output, nn.id); 630 | base_hnsw::writeBinaryPOD(output, nn.end_id); 631 | counter += 2; 632 | } 633 | } 634 | cout << "Total write " << counter << " (int) to file " << save_path << endl; 635 | } 636 | 637 
| void load(const string &load_path) { 638 | std::ifstream input(load_path, std::ios::binary); 639 | if (!input.is_open()) throw std::runtime_error("Cannot open file"); 640 | cout << sizeof(size_t) << endl; 641 | indexed_arr.clear(); 642 | indexed_arr.reserve(data_wrapper->data_size); 643 | int nn_num; 644 | int id; 645 | int end_id; 646 | for (size_t i = 0; i < data_wrapper->data_size; i++) { 647 | base_hnsw::readBinaryPOD(input, nn_num); 648 | 649 | vector> neighbors; 650 | for (size_t j = 0; j < nn_num; j++) { 651 | base_hnsw::readBinaryPOD(input, id); 652 | base_hnsw::readBinaryPOD(input, end_id); 653 | SegmentNeighbor1D nn(id, end_id); 654 | neighbors.emplace_back(nn); 655 | } 656 | indexed_arr.emplace_back(neighbors); 657 | } 658 | // printOnebatch(); 659 | countNeighbors(); 660 | cout << "Total # of neighbors: " << index_info->nodes_amount << endl; 661 | } 662 | }; 663 | 664 | } // namespace SeRF --------------------------------------------------------------------------------