├── .clang-format ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── data ├── SIFT.100.bin ├── SIFT.10K.bin ├── mnist.scale.1K.cws.bvecs └── mnist.scale.cws.bvecs ├── include ├── abort_if.hpp ├── array_index.hpp ├── art_index.hpp ├── bit_tools.hpp ├── cmd_line_parser │ └── parser.hpp ├── dyft_factors.hpp ├── dyft_interface.hpp ├── gv_index.hpp ├── ham_tables.hpp ├── hms1dv_index.hpp ├── hms1v_index.hpp ├── io.hpp ├── mart_array_dense.hpp ├── mart_array_full.hpp ├── mart_array_sparse.hpp ├── mart_common.hpp ├── mart_index.hpp ├── mi_frame.hpp ├── misc.hpp ├── sparse_group.hpp ├── sparse_table.hpp ├── splitmix.hpp ├── statistic_reporter.hpp ├── timer.hpp ├── tinyformat │ └── tinyformat.h ├── vcode_array.hpp └── vcode_tools.hpp └── src ├── CMakeLists.txt ├── build_index_bin.cpp ├── build_index_int.cpp ├── bvecs_to_bin.cpp ├── gen_uniform.cpp ├── precompute.cpp ├── range_search_bin.cpp ├── range_search_int.cpp ├── sample_bin.cpp ├── sample_int.cpp └── simhash.cpp /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: Google 4 | AccessModifierOffset: -2 5 | AlignAfterOpenBracket: Align 6 | AlignConsecutiveAssignments: false 7 | AlignConsecutiveDeclarations: false 8 | AlignEscapedNewlines: Left 9 | AlignOperands: true 10 | AlignTrailingComments: false 11 | AllowAllParametersOfDeclarationOnNextLine: true 12 | AllowShortBlocksOnASingleLine: false 13 | AllowShortCaseLabelsOnASingleLine: false 14 | AllowShortFunctionsOnASingleLine: Empty 15 | AllowShortIfStatementsOnASingleLine: true 16 | AllowShortLoopsOnASingleLine: true 17 | AlwaysBreakAfterDefinitionReturnType: None 18 | AlwaysBreakAfterReturnType: None 19 | AlwaysBreakBeforeMultilineStrings: true 20 | AlwaysBreakTemplateDeclarations: true 21 | BinPackArguments: true 22 | BinPackParameters: true 23 | BraceWrapping: 24 | AfterClass: false 25 | AfterControlStatement: false 26 | AfterEnum: false 27 | AfterFunction: false 28 
| AfterNamespace: false 29 | AfterObjCDeclaration: false 30 | AfterStruct: false 31 | AfterUnion: false 32 | AfterExternBlock: false 33 | BeforeCatch: false 34 | BeforeElse: false 35 | IndentBraces: false 36 | SplitEmptyFunction: true 37 | SplitEmptyRecord: true 38 | SplitEmptyNamespace: true 39 | BreakBeforeBinaryOperators: None 40 | BreakBeforeBraces: Attach 41 | BreakBeforeInheritanceComma: false 42 | BreakBeforeTernaryOperators: true 43 | BreakConstructorInitializersBeforeComma: false 44 | BreakConstructorInitializers: BeforeColon 45 | BreakAfterJavaFieldAnnotations: false 46 | BreakStringLiterals: true 47 | ColumnLimit: 120 48 | CommentPragmas: '^ IWYU pragma:' 49 | CompactNamespaces: false 50 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 51 | ConstructorInitializerIndentWidth: 4 52 | ContinuationIndentWidth: 4 53 | Cpp11BracedListStyle: true 54 | DerivePointerAlignment: false 55 | DisableFormat: false 56 | ExperimentalAutoDetectBinPacking: false 57 | FixNamespaceComments: true 58 | ForEachMacros: 59 | - foreach 60 | - Q_FOREACH 61 | - BOOST_FOREACH 62 | IncludeBlocks: Preserve 63 | IncludeCategories: 64 | - Regex: '^' 65 | Priority: 2 66 | - Regex: '^<.*\.h>' 67 | Priority: 1 68 | - Regex: '^<.*' 69 | Priority: 2 70 | - Regex: '.*' 71 | Priority: 3 72 | IncludeIsMainRegex: '([-_](test|unittest))?$' 73 | IndentCaseLabels: true 74 | IndentPPDirectives: None 75 | IndentWidth: 4 76 | IndentWrappedFunctionNames: false 77 | JavaScriptQuotes: Leave 78 | JavaScriptWrapImports: true 79 | KeepEmptyLinesAtTheStartOfBlocks: false 80 | MacroBlockBegin: '' 81 | MacroBlockEnd: '' 82 | MaxEmptyLinesToKeep: 1 83 | NamespaceIndentation: None 84 | ObjCBlockIndentWidth: 2 85 | ObjCSpaceAfterProperty: false 86 | ObjCSpaceBeforeProtocolList: false 87 | PenaltyBreakAssignment: 2 88 | PenaltyBreakBeforeFirstCallParameter: 1 89 | PenaltyBreakComment: 300 90 | PenaltyBreakFirstLessLess: 120 91 | PenaltyBreakString: 1000 92 | PenaltyExcessCharacter: 1000000 93 | 
PenaltyReturnTypeOnItsOwnLine: 200 94 | PointerAlignment: Left 95 | ReflowComments: true 96 | SortIncludes: true 97 | SortUsingDeclarations: true 98 | SpaceAfterCStyleCast: false 99 | SpaceAfterTemplateKeyword: true 100 | SpaceBeforeAssignmentOperators: true 101 | SpaceBeforeParens: ControlStatements 102 | SpaceInEmptyParentheses: false 103 | SpacesBeforeTrailingComments: 2 104 | SpacesInAngles: false 105 | SpacesInContainerLiterals: true 106 | SpacesInCStyleCastParentheses: false 107 | SpacesInParentheses: false 108 | SpacesInSquareBrackets: false 109 | Standard: Auto 110 | TabWidth: 8 111 | UseTab: Never 112 | ... 113 | 114 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | 34 | # Mydef 35 | build*/ 36 | scripts 37 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0) 2 | project(dyft) 3 | 4 | if(NOT CMAKE_BUILD_TYPE) 5 | # set(CMAKE_BUILD_TYPE "Debug") 6 | set(CMAKE_BUILD_TYPE "Release") 7 | endif(NOT CMAKE_BUILD_TYPE) 8 | 9 | # C++17 compiler check 10 | if (CMAKE_CXX_COMPILER MATCHES ".*clang.*" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") 11 | set(CMAKE_COMPILER_IS_CLANGXX 1) 12 | endif () 13 | if (CMAKE_CXX_COMPILER_ID STREQUAL "Intel") 14 | set(CMAKE_COMPILER_IS_INTEL 1) 15 | endif () 16 | if ((CMAKE_COMPILER_IS_GNUCXX AND ${CMAKE_CXX_COMPILER_VERSION} 
VERSION_LESS 7.0) OR (CMAKE_COMPILER_IS_CLANGXX AND ${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.0)) 17 | message(FATAL_ERROR "Your C++ compiler does not support C++17. Please install g++ 7.0 (or greater) or clang 4.0 (or greater)") 18 | else () 19 | message(STATUS "Compiler is recent enough to support C++17.") 20 | endif () 21 | 22 | # Boost 23 | set(Boost_USE_MULTITHREADED ON) # https://gitlab.kitware.com/cmake/cmake/issues/19714 24 | set(Boost_USE_STATIC_LIBS ON) 25 | find_package(Boost REQUIRED COMPONENTS system filesystem iostreams timer date_time) 26 | include_directories(${Boost_INCLUDE_DIRS}) 27 | 28 | # SSE Popcnt 29 | set(BUILTIN_POPCNT 0) 30 | 31 | if (DISABLE_SSE4_2) 32 | message(STATUS "sse4.2 disabled") 33 | elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux") 34 | # Use /proc/cpuinfo to get the information 35 | file(STRINGS "/proc/cpuinfo" _cpuinfo) 36 | if(_cpuinfo MATCHES "(sse4_2)|(sse4a)") 37 | set(BUILTIN_POPCNT 1) 38 | endif() 39 | elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin") 40 | execute_process(COMMAND sysctl -n machdep.cpu.features OUTPUT_VARIABLE _cpuinfo OUTPUT_STRIP_TRAILING_WHITESPACE) 41 | # message(STATUS "_cpuinfo is ${_cpuinfo}") 42 | if(_cpuinfo MATCHES "SSE4.2") 43 | set(BUILTIN_POPCNT 1) 44 | endif() 45 | endif() 46 | 47 | if(BUILTIN_POPCNT) 48 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2") 49 | endif() 50 | 51 | # AVX2 52 | set(SIMD_COMPARISON 0) 53 | 54 | if (DISABLE_AVX) 55 | message(STATUS "AVX disabled") 56 | elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux") 57 | # # Use /proc/cpuinfo to get the information 58 | file(STRINGS "/proc/cpuinfo" _cpuinfo) 59 | if(_cpuinfo MATCHES "avx2") 60 | set(SIMD_COMPARISON 2) 61 | elseif(_cpuinfo MATCHES "avx") 62 | set(SIMD_COMPARISON 1) 63 | endif() 64 | elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin") 65 | execute_process(COMMAND sysctl -n machdep.cpu.leaf7_features OUTPUT_VARIABLE _cpuinfo OUTPUT_STRIP_TRAILING_WHITESPACE) 66 | # message(STATUS "_cpuinfo is ${_cpuinfo}") 67 | if(_cpuinfo MATCHES "AVX2") 
68 | set(SIMD_COMPARISON 2) 69 | else() 70 | execute_process(COMMAND sysctl -n machdep.cpu.features OUTPUT_VARIABLE _cpuinfo OUTPUT_STRIP_TRAILING_WHITESPACE) 71 | # message(STATUS "_cpuinfo is ${_cpuinfo}") 72 | if(_cpuinfo MATCHES "AVX1.0") 73 | set(SIMD_COMPARISON 1) 74 | endif() 75 | endif() 76 | endif() 77 | 78 | if(SIMD_COMPARISON EQUAL 1) 79 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx") 80 | elseif(SIMD_COMPARISON EQUAL 2) 81 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2") 82 | endif() 83 | 84 | 85 | # -fvisibility=hidden is enabled to ignore "ld: warning: direct access in function for boost::filesystem." 86 | # https://stackoverflow.com/questions/36567072/why-do-i-get-ld-warning-direct-access-in-main-to-global-weak-symbol-in-this 87 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++1z -pthread -fvisibility=hidden -Wall") 88 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -march=native -O3") 89 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address -fno-omit-frame-pointer -O0 -g -DDEBUG") 90 | 91 | message(STATUS "BUILD_TYPE is ${CMAKE_BUILD_TYPE}") 92 | message(STATUS "CXX_FLAGS are ${CMAKE_CXX_FLAGS}") 93 | message(STATUS "CXX_FLAGS_DEBUG are ${CMAKE_CXX_FLAGS_DEBUG}") 94 | message(STATUS "CXX_FLAGS_RELEASE are ${CMAKE_CXX_FLAGS_RELEASE}") 95 | 96 | include_directories(include) 97 | add_subdirectory(src) 98 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Shunsuke Kanda 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom 
the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamic Filter Trie (DyFT) 2 | 3 | This is an experimental library for data structures described in the following papers: 4 | 5 | - Shunsuke Kanda and Yasuo Tabei, [Dynamic Similarity Search on Integer Sketches](https://ieeexplore.ieee.org/document/9338383/), In *Proc. 20th IEEE ICDM*, pp 242–251, 2020 [[arXiv](https://arxiv.org/abs/2009.11559)] 6 | - Shunsuke Kanda and Yasuo Tabei, [DyFT: A Dynamic Similarity Search Method on Integer Sketches](https://link.springer.com/article/10.1007%2Fs10115-021-01611-2), *Knowledge and Information Systems*, 63, 2815–2840, 2021 [[SharedIt](https://rdcu.be/cxu1J)] 7 | 8 | ## Build instructions 9 | 10 | You can download and compile this library by the following commands: 11 | 12 | ```shell 13 | $ git clone https://github.com/kampersanda/dyft.git 14 | $ cd dyft 15 | $ mkdir build 16 | $ cd build 17 | $ cmake .. 18 | $ make -j 19 | ``` 20 | 21 | After the commands, the executables will be produced in `build/bin` directory. 22 | 23 | The code is written in C++17, so please install g++ >= 7.0 or clang >= 4.0. 
The following dependencies must be installed to compile the library: CMake >= 3.0 (build system) and Boost >= 1.42. 24 | 25 | The library employs the third-party libraries [cmd\_line\_parser](https://github.com/jermp/cmd_line_parser) and [tinyformat](https://github.com/c42f/tinyformat), whose header files are included in this repository. 26 | 27 | ## Data structures 28 | 29 | The library contains implementations of DyFT and of the other data structures used in the experiments (Section VI of our paper), as follows. 30 | 31 | ### DyFT family 32 | 33 | - `mart_index` is an implementation of DyFT with MART representation. 34 | - `art_index` is an implementation of DyFT with ART representation. 35 | - `array_index` is an implementation of DyFT with Array representation. 36 | - `mi_frame` is an implementation of DyFT+ (i.e., the multi-index variant) built on the DyFT class given as template argument `T`. 37 | 38 | ### Others 39 | 40 | - `hms1v_index` is an implementation of [HmSearch 1-var (HSV)](https://doi.org/10.1145/2484838.2484842), which is designed for binary sketches. 41 | - `hms1dv_index` is an implementation of [HmSearch 1-del-var (HSD)](https://doi.org/10.1145/2484838.2484842), which is designed for integer sketches. 42 | - `gv_index` is an implementation of multi-index hashing for integer sketches, which follows an idea by [Gog and Venturini](https://doi.org/10.1145/2911451.2911523). 43 | 44 | ## Input data format 45 | 46 | ### Binary sketches 47 | 48 | Binary sketches should be stored in binary format, where each sketch is 64 bits in size. The executables provided by the library let you specify the number of dimensions (32 or 64) to evaluate through argument `-N`, and they use the first `N` dimensions. 49 | 50 | You can use `src/simhash.cpp` to generate such a dataset using [Charikar’s simhash](https://doi.org/10.1145/509907.509965) algorithm.
51 | 52 | ### Integer sketches 53 | 54 | Integer sketches should be stored in [TEXMEX's bvecs format](http://corpus-texmex.irisa.fr/). That is, each sketch is stored as its number of dimensions followed by its features, where the number is 4 bytes in size and each feature is 1 byte in size. As with binary sketches, the executables provided by the library let you specify the number of dimensions (32 or 64) to evaluate through argument `-N`. In addition, argument `-B` specifies the number of lowest bits to evaluate for each feature (1 to 8). 55 | 56 | You can use [consistent\_weighted\_sampling](https://github.com/kampersanda/consistent_weighted_sampling) to generate such a dataset using the [GCWS](https://doi.org/10.1145/3097983.3098081) algorithm. 57 | 58 | ## Usage 59 | 60 | The `data` directory contains the following tiny datasets: 61 | 62 | - `SIFT.10K.bin` and `SIFT.100.bin` are binary sketches generated from [ANN\_SIFT10K](http://corpus-texmex.irisa.fr/) using Charikar’s simhash algorithm. 63 | - `mnist.scale.cws.bvecs` and `mnist.scale.1K.cws.bvecs` are 1-byte integer sketches generated from [mnist](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#mnist) using the GCWS algorithm. 64 | 65 | Using these datasets, you can test the data structures as follows. 66 | 67 | ### Benchmark for range search 68 | 69 | The executables `range_search_*` are used to analyze the range search time of each data structure. Given dataset and query files, the executable inserts data sketches one by one and performs range search for the queries whenever the number of inserted sketches is a power of ten. The benchmark results are output as a JSON file.
70 | 71 | #### Example: DyFT with MART on binary sketches 72 | 73 | ```shell 74 | $ ./bin/range_search_bin_dyft ../data/SIFT.10K.bin ../data/SIFT.100.bin -o results -R 2 -N 32 -A mart -K 1 75 | ``` 76 | 77 | The command tests range search with radius `R=2` for a single `mart_index` (`K=1`) on dataset `SIFT.10K.bin` and query set `SIFT.100.bin`, where the number of dimensions is `N=32` (i.e., the first 32 dimensions are used). The benchmark result is written to the `results` directory, and its file name is assigned automatically. 78 | 79 | #### Example: DyFT+ with ART on integer sketches 80 | 81 | ```shell 82 | $ ./bin/range_search_int_dyft ../data/mnist.scale.cws.bvecs ../data/mnist.scale.1K.cws.bvecs -o results -R 4 -N 64 -B 4 -A art -K 3 83 | ``` 84 | 85 | The command tests range search with radius `R=4` for `mi_frame` with `K=3` blocks on dataset `mnist.scale.cws.bvecs` and query set `mnist.scale.1K.cws.bvecs`, where the number of dimensions is `N=64` and the lowest `B=4` bits are used for each feature. 86 | 87 | #### Example: GV on integer sketches 88 | 89 | ```shell 90 | $ ./bin/range_search_int_gv ../data/mnist.scale.cws.bvecs ../data/mnist.scale.1K.cws.bvecs -o results -R 6 -N 64 -B 8 91 | ``` 92 | 93 | The command tests range search with radius `R=6` for `gv_index` on dataset `mnist.scale.cws.bvecs` and query set `mnist.scale.1K.cws.bvecs`, where the number of dimensions is `N=64` and all `B=8` bits are used for each feature. 94 | 95 | ### Benchmark for construction 96 | 97 | The executables `build_index_*` are used to analyze the insertion time and process size of each data structure. Given a dataset, the executable inserts data sketches one by one and reports the statistics whenever the number of inserted sketches is a power of ten. The benchmark results are output as a JSON file.
98 | 99 | #### Example: DyFT with Array on binary sketches 100 | 101 | ```shell 102 | $ ./bin/build_index_bin_dyft ../data/SIFT.10K.bin -o results -R 2 -N 32 -A array -K 1 103 | ``` 104 | 105 | The command builds a single `array_index` (`K=1`) for radius `R=2` on dataset `SIFT.10K.bin`, where the number of dimensions is `N=32`. 106 | 107 | #### Example: DyFT+ with MART on integer sketches 108 | 109 | ```shell 110 | $ ./bin/build_index_int_dyft ../data/mnist.scale.cws.bvecs -o results -R 4 -N 64 -B 4 -A mart -K 0 111 | ``` 112 | 113 | The command builds `mi_frame` with an automatically chosen number of blocks for radius `R=4` on dataset `mnist.scale.cws.bvecs`, where the number of dimensions is `N=64` and the lowest `B=4` bits are used for each feature. 114 | 115 | When `K` is set to 0, the number of blocks is set to ⌊`R`/2⌋+1, following the idea of GV. 116 | 117 | #### Example: HSV on binary sketches 118 | 119 | ```shell 120 | $ ./bin/build_index_bin_hms1v ../data/SIFT.10K.bin -o results -R 6 -N 64 121 | ``` 122 | 123 | The command builds `hms1v_index` for radius `R=6` on dataset `SIFT.10K.bin`, where the number of dimensions is `N=64`. 124 | 125 | ## Licensing 126 | 127 | This library is free software provided under the [MIT License](https://github.com/kampersanda/dyft/blob/master/LICENSE).
128 | 129 | If you use the library, please cite the following paper: 130 | 131 | ``` 132 | @inproceedings{kanda2020dynamic, 133 | author = {Kanda, Shunsuke and Tabei, Yasuo}, 134 | title = {Dynamic Similarity Search on Integer Sketches}, 135 | booktitle = {Proceedings of the 20th IEEE International Conference on Data Mining (ICDM)}, 136 | pages={242-251}, 137 | year = {2020} 138 | } 139 | ``` 140 | 141 | -------------------------------------------------------------------------------- /data/SIFT.100.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kampersanda/dyft/9c675b60817542a3fccbfb76567f9b6a4ec935a9/data/SIFT.100.bin -------------------------------------------------------------------------------- /data/SIFT.10K.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kampersanda/dyft/9c675b60817542a3fccbfb76567f9b6a4ec935a9/data/SIFT.10K.bin -------------------------------------------------------------------------------- /data/mnist.scale.1K.cws.bvecs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kampersanda/dyft/9c675b60817542a3fccbfb76567f9b6a4ec935a9/data/mnist.scale.1K.cws.bvecs -------------------------------------------------------------------------------- /data/mnist.scale.cws.bvecs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kampersanda/dyft/9c675b60817542a3fccbfb76567f9b6a4ec935a9/data/mnist.scale.cws.bvecs -------------------------------------------------------------------------------- /include/abort_if.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | /* ABORT_IF_(x, y, z): if condition x is true, writes the source file, the line, the stringified condition and std::to_string(y) / std::to_string(z) to std::cerr (between ANSI red and reset escape codes), then calls abort(). Through the ABORT_IF_<OP> wrappers below, y and z are the compared operands, so they are evaluated a second time for the message and should be side-effect free and of a type std::to_string accepts. NOTE(review): the expansion is a bare `if (...) { ... }` rather than `do { ... } while (0)`, so `if (c) ABORT_IF_(...); else ...` does not compile. */ 6 | #define ABORT_IF_(x, y, z) \ 7 | if ((x)) { \ 8 | std::cerr << "\033[0;31mERROR in " << __FILE__ << " at line " << std::to_string(__LINE__)
<< ", if " << (#x) \ 9 | << " for " << std::to_string(y) << " vs " << std::to_string(z) << "\033[0;0m" << std::endl; \ 10 | abort(); \ 11 | } 12 | /* ABORT_IF(x): same as ABORT_IF_, but the message contains only the file, the line and the stringified condition x. */ 13 | #define ABORT_IF(x) \ 14 | if ((x)) { \ 15 | std::cerr << "\033[0;31mERROR in " << __FILE__ << " at line " << std::to_string(__LINE__) << ", if " << (#x) \ 16 | << "\033[0;0m" << std::endl; \ 17 | abort(); \ 18 | } 19 | /* Comparison helpers: ABORT_IF_<OP>(x, y) aborts when `(x) OP (y)` holds and prints both operands. ABORT_IF_OUT(n, x, y) aborts when n < x or y < n, i.e. when n lies outside [x, y]. NOTE(review): ABORT_IF_OUT expands to two consecutive if-statements, so when it is used as the unbraced body of an if/for/while only the first check is controlled by that statement. */ 20 | #define ABORT_IF_EQ(x, y) ABORT_IF_((x) == (y), x, y) 21 | #define ABORT_IF_NE(x, y) ABORT_IF_((x) != (y), x, y) 22 | #define ABORT_IF_LE(x, y) ABORT_IF_((x) <= (y), x, y) 23 | #define ABORT_IF_LT(x, y) ABORT_IF_((x) < (y), x, y) 24 | #define ABORT_IF_GE(x, y) ABORT_IF_((x) >= (y), x, y) 25 | #define ABORT_IF_GT(x, y) ABORT_IF_((x) > (y), x, y) 26 | #define ABORT_IF_OUT(n, x, y) ABORT_IF_((n) < (x), n, x) ABORT_IF_((y) < (n), y, n) 27 | /* DEBUG_ABORT_IF variants: identical to the corresponding checks above, but defined as empty when NDEBUG is defined, so neither the check nor its argument expressions are evaluated in such builds. */ 28 | #ifdef NDEBUG 29 | #define DEBUG_ABORT_IF(x) 30 | #define DEBUG_ABORT_IF_EQ(x, y) 31 | #define DEBUG_ABORT_IF_NE(x, y) 32 | #define DEBUG_ABORT_IF_LE(x, y) 33 | #define DEBUG_ABORT_IF_LT(x, y) 34 | #define DEBUG_ABORT_IF_GE(x, y) 35 | #define DEBUG_ABORT_IF_GT(x, y) 36 | #define DEBUG_ABORT_IF_OUT(n, x, y) 37 | #else 38 | #define DEBUG_ABORT_IF(x) ABORT_IF((x)) 39 | #define DEBUG_ABORT_IF_EQ(x, y) ABORT_IF_((x) == (y), x, y) 40 | #define DEBUG_ABORT_IF_NE(x, y) ABORT_IF_((x) != (y), x, y) 41 | #define DEBUG_ABORT_IF_LE(x, y) ABORT_IF_((x) <= (y), x, y) 42 | #define DEBUG_ABORT_IF_LT(x, y) ABORT_IF_((x) < (y), x, y) 43 | #define DEBUG_ABORT_IF_GE(x, y) ABORT_IF_((x) >= (y), x, y) 44 | #define DEBUG_ABORT_IF_GT(x, y) ABORT_IF_((x) > (y), x, y) 45 | #define DEBUG_ABORT_IF_OUT(n, x, y) ABORT_IF_((n) < (x), n, x) ABORT_IF_((y) < (n), y, n) 46 | #endif 47 | -------------------------------------------------------------------------------- /include/array_index.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "dyft_factors.hpp" 4 | #include "dyft_interface.hpp" 5 | #include "sparse_table.hpp" 6 |
#include "vcode_array.hpp" 7 | 8 | namespace dyft { 9 | 10 | template 11 | class array_index : public dyft_interface { 12 | public: 13 | using vint_type = typename vcode_tools::vint_type; 14 | 15 | static constexpr int LEN = N; 16 | static constexpr int NOT_FOUND = std::numeric_limits::max(); 17 | 18 | static int get_chunks_size(int) { 19 | return LEN; 20 | } 21 | 22 | private: 23 | const vcode_array* m_database = nullptr; 24 | 25 | const int m_sigma = 0; 26 | const int m_radius = 0; 27 | const int m_splitthr = 0; 28 | const int* m_splitthrs_ptr = nullptr; 29 | 30 | // For computing trie cost 31 | const double* m_infacts_ptr = nullptr; 32 | const double* m_outfacts_ptr = nullptr; 33 | 34 | // Trie 35 | std::vector m_fanouts; 36 | sparse_table m_idseqs; 37 | uint32_t m_ids = 0; // num ids 38 | 39 | const int m_depth_beg = 0; 40 | const int m_depth_end = N; 41 | 42 | // double m_bf_cost = 0.0; 43 | double m_trie_cost = 0.0; 44 | double m_in_weight = 1.0; 45 | 46 | // For query processing 47 | struct state_type { 48 | int npos; 49 | int depth; 50 | int dist; 51 | }; 52 | std::vector m_states; 53 | 54 | // statistics 55 | size_t m_split_count = 0; 56 | 57 | public: 58 | explicit array_index(const vcode_array* database, int radius, int splitthr, double in_weight) 59 | : array_index(database, radius, 0, N, splitthr, in_weight) {} 60 | 61 | explicit array_index(const vcode_array* database, int radius, int depth_beg, int depth_end, // 62 | int splitthr, double in_weight) 63 | : m_database(database), m_sigma(1 << m_database->get_bits()), m_radius(radius), m_splitthr(splitthr), 64 | m_splitthrs_ptr(m_splitthr <= 0 ? dyft_factors::get_splitthrs(m_database->get_bits(), m_radius) : nullptr), 65 | m_infacts_ptr(m_splitthr <= 0 ? dyft_factors::get_infacts(m_database->get_bits(), m_radius) : nullptr), 66 | m_outfacts_ptr(m_splitthr <= 0 ? 
dyft_factors::get_outfacts(m_database->get_bits(), m_radius) : nullptr), 67 | m_fanouts(m_sigma, NOT_FOUND), m_depth_beg(depth_beg), m_depth_end(depth_end), m_in_weight(in_weight) { 68 | m_states.reserve(1 << 10); 69 | } 70 | 71 | uint32_t append() override { 72 | ABORT_IF_LE(m_database->get_size(), m_ids); 73 | 74 | const uint32_t new_id = m_ids++; 75 | const vint_type* vcode = m_database->access(new_id); 76 | 77 | int npos = 0, depth = m_depth_beg; 78 | 79 | while (true) { 80 | const int c = vcode_tools::get_int(vcode, depth++, get_bits()); 81 | const int cpos = npos + c; 82 | 83 | if (m_fanouts[cpos] == NOT_FOUND) { 84 | npos = cpos; 85 | m_fanouts[npos] = -1 * int(m_idseqs.size()); 86 | m_idseqs.push_back(); 87 | break; 88 | } 89 | 90 | if (m_fanouts[cpos] <= 0) { // Leaf? 91 | npos = cpos; 92 | break; 93 | } 94 | 95 | npos = m_fanouts[cpos] * m_sigma; 96 | } 97 | 98 | const int lpos = -1 * m_fanouts[npos]; 99 | m_idseqs.insert(lpos, new_id); 100 | 101 | if (m_outfacts_ptr) { 102 | m_trie_cost += m_outfacts_ptr[depth - m_depth_beg - 1]; 103 | } 104 | 105 | if (depth == m_depth_end) { 106 | return new_id; 107 | } 108 | 109 | const int cnt = int(m_idseqs.group_size(lpos)); 110 | 111 | if (m_splitthrs_ptr == nullptr) { 112 | if (m_splitthr <= cnt) { 113 | split_node(npos, depth); 114 | } 115 | } else { 116 | if (m_splitthrs_ptr[depth - 1] * m_in_weight <= cnt) { 117 | split_node(npos, depth); 118 | } 119 | } 120 | return new_id; 121 | } 122 | 123 | void range_search(const vint_type* vcode, const std::function& fn) override { 124 | return selects_ls() ? 
ls_search(vcode, fn) : trie_search(vcode, fn); 125 | } 126 | 127 | void trie_search(const vint_type* vcode, const std::function& fn) { 128 | m_states.clear(); 129 | m_states.push_back(state_type{0, m_depth_beg, 0}); 130 | 131 | int code[N]; 132 | for (int i = m_depth_beg; i < m_depth_end; i++) { 133 | code[i] = vcode_tools::get_int(vcode, i, get_bits()); 134 | } 135 | 136 | while (!m_states.empty()) { 137 | state_type s = m_states.back(); 138 | m_states.pop_back(); 139 | 140 | if (s.dist < m_radius) { 141 | const int c = code[s.depth]; 142 | 143 | for (int i = 0; i < m_sigma; ++i) { 144 | const int cpos = s.npos + i; 145 | 146 | if (m_fanouts[cpos] == NOT_FOUND) { 147 | continue; 148 | } 149 | 150 | if (m_fanouts[cpos] <= 0) { // leaf 151 | const int lpos = -1 * m_fanouts[cpos]; 152 | auto [bptr, eptr] = m_idseqs.access(lpos); 153 | for (auto it = bptr; it != eptr; ++it) { 154 | fn(*it); 155 | } 156 | } else { // internal 157 | if (i == c) { 158 | m_states.push_back(state_type{m_fanouts[cpos] * m_sigma, s.depth + 1, s.dist}); 159 | } else { 160 | m_states.push_back(state_type{m_fanouts[cpos] * m_sigma, s.depth + 1, s.dist + 1}); 161 | } 162 | } 163 | } 164 | } else { // exact match 165 | while (true) { 166 | const int c = code[s.depth]; 167 | const int cpos = s.npos + c; 168 | 169 | if (m_fanouts[cpos] == NOT_FOUND) { 170 | break; 171 | } 172 | 173 | if (m_fanouts[cpos] <= 0) { // leaf 174 | const int lpos = -1 * m_fanouts[cpos]; 175 | auto [bptr, eptr] = m_idseqs.access(lpos); 176 | for (auto it = bptr; it != eptr; ++it) { 177 | fn(*it); 178 | } 179 | break; 180 | } 181 | 182 | // internal 183 | s.npos = m_fanouts[cpos] * m_sigma; 184 | s.depth += 1; 185 | } 186 | } 187 | } 188 | } 189 | 190 | void ls_search(const vint_type* vcode, const std::function& fn) { 191 | for (uint32_t id = 0; id < get_size(); id++) { 192 | fn(id); 193 | } 194 | } 195 | 196 | uint32_t get_size() const override { 197 | return m_ids; 198 | } 199 | uint32_t get_leaves() const override { 200 
| return m_idseqs.size(); 201 | } 202 | int get_bits() const override { 203 | return m_database->get_bits(); 204 | } 205 | 206 | double get_ls_cost() const override { 207 | return static_cast(m_ids) * get_bits(); 208 | } 209 | double get_trie_cost() const override { 210 | return m_trie_cost; 211 | } 212 | 213 | bool selects_ls() const override { 214 | return m_infacts_ptr and get_ls_cost() < get_trie_cost(); 215 | } 216 | 217 | size_t get_split_count() override { 218 | return m_split_count; 219 | } 220 | 221 | void innode_stats() override { 222 | tfm::printfln("innode_stats: %d", m_fanouts.size() / m_sigma); 223 | } 224 | void population_stats() override {} 225 | 226 | private: 227 | void split_node(const int npos, const int depth) { 228 | ABORT_IF_LT(0, m_fanouts[npos]); 229 | ABORT_IF_LT(depth, m_depth_beg + 1); 230 | 231 | m_split_count += 1; 232 | 233 | if (m_fanouts.size() == m_fanouts.capacity()) { 234 | m_fanouts.reserve(m_fanouts.size() * 2); 235 | } 236 | 237 | const int new_npos = static_cast(m_fanouts.size()); 238 | m_fanouts.resize(new_npos + m_sigma, NOT_FOUND); 239 | 240 | int lpos = -1 * m_fanouts[npos]; 241 | ABORT_IF_LE(m_idseqs.size(), size_t(lpos)); 242 | 243 | // update cost 244 | if (m_infacts_ptr) { 245 | const int _depth = depth - m_depth_beg - 1; 246 | const int g_size = m_idseqs.group_size(lpos); 247 | 248 | m_trie_cost -= g_size * m_outfacts_ptr[_depth]; 249 | m_trie_cost += (m_infacts_ptr[_depth] * m_in_weight) + (g_size * m_outfacts_ptr[_depth + 1]); 250 | } 251 | 252 | std::vector > idbufs(m_sigma); 253 | { 254 | const auto ids = m_idseqs.extract(lpos); 255 | for (uint32_t id : ids) { 256 | const vint_type* vcode = m_database->access(id); 257 | const int c = vcode_tools::get_int(vcode, depth, get_bits()); 258 | idbufs[c].push_back(id); 259 | } 260 | } 261 | 262 | for (int c = 0; c < m_sigma; c++) { 263 | if (idbufs[c].empty()) { 264 | continue; 265 | } 266 | 267 | const int new_cpos = new_npos + c; 268 | 269 | if (lpos != NOT_FOUND) { 
270 | // the first new idseq 271 | m_fanouts[new_cpos] = -1 * lpos; 272 | m_idseqs.insert(lpos, idbufs[c]); 273 | lpos = NOT_FOUND; 274 | } else { 275 | m_fanouts[new_cpos] = -1 * int(m_idseqs.size()); 276 | m_idseqs.push_back(idbufs[c]); 277 | } 278 | } 279 | 280 | m_fanouts[npos] = new_npos / m_sigma; 281 | } 282 | }; 283 | 284 | } // namespace dyft 285 | -------------------------------------------------------------------------------- /include/art_index.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "dyft_factors.hpp" 4 | #include "dyft_interface.hpp" 5 | #include "mart_array_dense.hpp" 6 | #include "mart_array_full.hpp" 7 | #include "mart_array_sparse.hpp" 8 | #include "sparse_table.hpp" 9 | #include "vcode_array.hpp" 10 | 11 | namespace dyft { 12 | 13 | template 14 | class art_index : public dyft_interface { 15 | public: 16 | using vint_type = typename vcode_tools::vint_type; 17 | 18 | using array_4_type = mart_array_sparse<4, MART_NODE_4>; 19 | using array_16_type = mart_array_sparse<16, MART_NODE_16>; 20 | using array_48_type = mart_array_dense<48, MART_NODE_64>; // instead of MART_NODE_48 21 | using array_256_type = mart_array_full; 22 | 23 | static constexpr int LEN = N; 24 | static constexpr int NOT_FOUND = std::numeric_limits::max(); 25 | 26 | static int get_chunks_size(int) { 27 | return LEN; 28 | } 29 | 30 | private: 31 | const vcode_array* m_database = nullptr; 32 | 33 | const int m_sigma = 0; 34 | const int m_radius = 0; 35 | const int m_splitthr = 0; 36 | const int* m_splitthrs_ptr = nullptr; 37 | 38 | // For computing trie cost 39 | const double* m_infacts_ptr = nullptr; 40 | const double* m_outfacts_ptr = nullptr; 41 | 42 | // ART 43 | array_4_type m_array_4; 44 | array_16_type m_array_16; 45 | array_48_type m_array_48; 46 | array_256_type m_array_256; 47 | 48 | mart_pointer m_rootptr; 49 | 50 | // std::vector m_fanouts; 51 | sparse_table m_idseqs; 52 | uint32_t m_ids = 0; // 
num ids 53 | 54 | const int m_depth_beg = 0; 55 | const int m_depth_end = N; 56 | 57 | // double m_bf_cost = 0.0; 58 | double m_trie_cost = 0.0; 59 | double m_in_weight = 1.0; 60 | 61 | // For query processing 62 | struct state_type { 63 | mart_pointer nptr; 64 | int depth; 65 | int dist; 66 | }; 67 | std::vector m_states; 68 | 69 | array_4_type::searcher m_s4; 70 | array_16_type::searcher m_s16; 71 | array_48_type::searcher m_s48; 72 | array_256_type::searcher m_s256; 73 | std::vector m_edges; 74 | 75 | // statistics 76 | size_t m_split_count = 0; 77 | 78 | public: 79 | explicit art_index(const vcode_array* database, int radius, int splitthr, double in_weight) 80 | : art_index(database, radius, 0, N, splitthr, in_weight) {} 81 | 82 | explicit art_index(const vcode_array* database, int radius, int depth_beg, int depth_end, // 83 | int splitthr, double in_weight) 84 | : m_database(database), m_sigma(1 << m_database->get_bits()), m_radius(radius), m_splitthr(splitthr), 85 | m_splitthrs_ptr(m_splitthr <= 0 ? dyft_factors::get_splitthrs(m_database->get_bits(), m_radius) : nullptr), 86 | m_infacts_ptr(m_splitthr <= 0 ? dyft_factors::get_infacts(m_database->get_bits(), m_radius) : nullptr), 87 | m_outfacts_ptr(m_splitthr <= 0 ? 
dyft_factors::get_outfacts(m_database->get_bits(), m_radius) : nullptr), 88 | // m_fanouts(m_sigma, NOT_FOUND), 89 | m_depth_beg(depth_beg), m_depth_end(depth_end), m_in_weight(in_weight) { 90 | m_rootptr = m_array_256.make_node(); // make root 91 | m_states.reserve(1 << 10); 92 | m_edges.reserve(256); 93 | } 94 | 95 | uint32_t append() override { 96 | ABORT_IF_LE(m_database->get_size(), m_ids); 97 | 98 | const uint32_t new_id = m_ids++; 99 | const vint_type* vcode = m_database->access(new_id); 100 | 101 | // int npos = 0, depth = m_depth_beg; 102 | int depth = m_depth_beg; 103 | mart_cursor mc{UINT32_MAX, make_mart_nullptr(), m_rootptr}; 104 | 105 | const uint32_t new_lpos = m_idseqs.size(); 106 | const mart_pointer new_lptr{new_lpos, MART_LEAF}; 107 | 108 | while (true) { 109 | DEBUG_ABORT_IF_LE(m_depth_end, depth); 110 | const int c = vcode_tools::get_int(vcode, depth++, get_bits()); 111 | 112 | mart_insert_flags iflag; 113 | 114 | switch (mc.nptr.ntype) { 115 | case MART_NODE_4: { 116 | iflag = m_array_4.insert_ptr(mc, c, new_lptr); 117 | if (iflag == MART_NEEDED_TO_EXPAND) { 118 | iflag = expand(mc, c, new_lptr); 119 | ABORT_IF_NE(iflag, MART_INSERTED); 120 | } 121 | break; 122 | } 123 | case MART_NODE_16: { 124 | iflag = m_array_16.insert_ptr(mc, c, new_lptr); 125 | if (iflag == MART_NEEDED_TO_EXPAND) { 126 | iflag = expand(mc, c, new_lptr); 127 | ABORT_IF_NE(iflag, MART_INSERTED); 128 | } 129 | break; 130 | } 131 | case MART_NODE_64: { // instead of MART_NODE_48 132 | iflag = m_array_48.insert_ptr(mc, c, new_lptr); 133 | if (iflag == MART_NEEDED_TO_EXPAND) { 134 | iflag = expand(mc, c, new_lptr); 135 | ABORT_IF_NE(iflag, MART_INSERTED); 136 | } 137 | break; 138 | } 139 | case MART_NODE_256: { 140 | iflag = m_array_256.insert_ptr(mc, c, new_lptr); 141 | break; 142 | } 143 | default: { 144 | ABORT_IF(true); // should not come 145 | break; 146 | } 147 | } 148 | 149 | if (iflag == MART_INSERTED) { 150 | ABORT_IF_NE(mc.nptr.nid, new_lpos); 151 | 
ABORT_IF_NE(mc.nptr.ntype, MART_LEAF); 152 | m_idseqs.push_back(); 153 | break; 154 | } 155 | 156 | if (mc.nptr.ntype == MART_LEAF) { 157 | break; 158 | } 159 | } 160 | 161 | const uint32_t lpos = mc.nptr.nid; 162 | m_idseqs.insert(lpos, new_id); 163 | 164 | if (m_outfacts_ptr) { 165 | m_trie_cost += m_outfacts_ptr[depth - m_depth_beg - 1]; 166 | } 167 | 168 | if (depth == m_depth_end) { 169 | return new_id; 170 | } 171 | 172 | const int cnt = int(m_idseqs.group_size(lpos)); 173 | 174 | if (m_splitthrs_ptr == nullptr) { 175 | if (m_splitthr <= cnt) { 176 | split_node(mc, depth); 177 | } 178 | } else { 179 | if (m_splitthrs_ptr[depth - 1] * m_in_weight <= cnt) { 180 | split_node(mc, depth); 181 | } 182 | } 183 | return new_id; 184 | } 185 | 186 | void range_search(const vint_type* vcode, const std::function& fn) override { 187 | return selects_ls() ? ls_search(vcode, fn) : trie_search(vcode, fn); 188 | } 189 | 190 | void trie_search(const vint_type* vcode, const std::function& fn) { 191 | m_states.clear(); 192 | m_states.push_back(state_type{m_rootptr, m_depth_beg, 0}); 193 | 194 | int code[N]; 195 | for (int i = m_depth_beg; i < m_depth_end; i++) { 196 | code[i] = vcode_tools::get_int(vcode, i, get_bits()); 197 | } 198 | 199 | while (!m_states.empty()) { 200 | state_type s = m_states.back(); 201 | m_states.pop_back(); 202 | 203 | if (s.dist < m_radius) { 204 | const uint8_t c = code[s.depth]; 205 | m_edges.clear(); 206 | 207 | switch (s.nptr.ntype) { 208 | case MART_NODE_4: { 209 | m_s4.set(&m_array_4, s.nptr); 210 | m_s4.scan(m_edges); 211 | break; 212 | } 213 | case MART_NODE_16: { 214 | m_s16.set(&m_array_16, s.nptr); 215 | m_s16.scan(m_edges); 216 | break; 217 | } 218 | case MART_NODE_64: { 219 | m_s48.set(&m_array_48, s.nptr); 220 | m_s48.scan(m_edges); 221 | break; 222 | } 223 | case MART_NODE_256: { 224 | m_s256.set(&m_array_256, s.nptr); 225 | m_s256.scan(m_edges); 226 | break; 227 | } 228 | default: { 229 | ABORT_IF(true); // should not come 230 | break; 
231 | } 232 | } 233 | 234 | for (const mart_edge& e : m_edges) { 235 | if (e.nptr.ntype == MART_LEAF) { // leaf 236 | const uint32_t lpos = e.nptr.nid; 237 | auto [bptr, eptr] = m_idseqs.access(lpos); 238 | for (auto it = bptr; it != eptr; ++it) { 239 | fn(*it); 240 | } 241 | } else { // internal 242 | if (e.c == c) { 243 | m_states.push_back(state_type{e.nptr, s.depth + 1, s.dist}); 244 | } else { 245 | m_states.push_back(state_type{e.nptr, s.depth + 1, s.dist + 1}); 246 | } 247 | } 248 | } 249 | } else { // exact match 250 | while (true) { 251 | const uint8_t c = code[s.depth]; 252 | 253 | switch (s.nptr.ntype) { 254 | case MART_NODE_4: { 255 | s.nptr = m_array_4.find_child(s.nptr, c); 256 | break; 257 | } 258 | case MART_NODE_16: { 259 | s.nptr = m_array_16.find_child(s.nptr, c); 260 | break; 261 | } 262 | case MART_NODE_64: { 263 | s.nptr = m_array_48.find_child(s.nptr, c); 264 | break; 265 | } 266 | case MART_NODE_256: { 267 | s.nptr = m_array_256.find_child(s.nptr, c); 268 | break; 269 | } 270 | default: { 271 | ABORT_IF(true); // should not come 272 | break; 273 | } 274 | } 275 | 276 | if (s.nptr.nid == MART_NILID) { 277 | break; 278 | } 279 | 280 | if (s.nptr.ntype == MART_LEAF) { 281 | const uint32_t lpos = s.nptr.nid; 282 | auto [bptr, eptr] = m_idseqs.access(lpos); 283 | for (auto it = bptr; it != eptr; ++it) { 284 | fn(*it); 285 | } 286 | break; 287 | } 288 | 289 | s.depth += 1; 290 | } 291 | } 292 | } 293 | } 294 | 295 | void ls_search(const vint_type* vcode, const std::function& fn) { 296 | for (uint32_t id = 0; id < get_size(); id++) { 297 | fn(id); 298 | } 299 | } 300 | 301 | uint32_t get_size() const override { 302 | return m_ids; 303 | } 304 | uint32_t get_leaves() const override { 305 | return m_idseqs.size(); 306 | } 307 | int get_bits() const override { 308 | return m_database->get_bits(); 309 | } 310 | 311 | double get_ls_cost() const override { 312 | return static_cast(m_ids) * get_bits(); 313 | } 314 | double get_trie_cost() const override { 
315 | return m_trie_cost; 316 | } 317 | 318 | bool selects_ls() const override { 319 | return m_infacts_ptr and get_ls_cost() < get_trie_cost(); 320 | } 321 | 322 | size_t get_split_count() override { 323 | return m_split_count; 324 | } 325 | 326 | void innode_stats() override { 327 | tfm::printfln("## innode_stats ##"); 328 | tfm::printfln("- num node 4: %d (%d)", m_array_4.num_nodes(), m_array_4.num_emp_nodes()); 329 | tfm::printfln("- num node 16: %d (%d)", m_array_16.num_nodes(), m_array_16.num_emp_nodes()); 330 | tfm::printfln("- num node 48: %d (%d)", m_array_48.num_nodes(), m_array_48.num_emp_nodes()); 331 | tfm::printfln("- num node 256: %d (%d)", m_array_256.num_nodes(), m_array_256.num_emp_nodes()); 332 | } 333 | void population_stats() override {} 334 | 335 | private: 336 | mart_insert_flags expand(mart_cursor& mc, uint8_t c, mart_pointer new_ptr) { 337 | m_edges.clear(); 338 | 339 | switch (mc.nptr.ntype) { 340 | case MART_NODE_4: { 341 | m_array_4.extract(mc, m_edges); 342 | mc.nptr = m_array_16.make_node(m_edges); 343 | update_srcptr(mc); 344 | return m_array_16.append_ptr(mc, c, new_ptr); 345 | } 346 | case MART_NODE_16: { 347 | m_array_16.extract(mc, m_edges); 348 | mc.nptr = m_array_48.make_node(m_edges); 349 | update_srcptr(mc); 350 | return m_array_48.append_ptr(mc, c, new_ptr); 351 | } 352 | case MART_NODE_64: { // instead of MART_NODE_48 353 | m_array_48.extract(mc, m_edges); 354 | mc.nptr = m_array_256.make_node(m_edges); 355 | update_srcptr(mc); 356 | return m_array_256.insert_ptr(mc, c, new_ptr); 357 | } 358 | default: { 359 | ABORT_IF(true); // should not come 360 | break; 361 | } 362 | } 363 | } 364 | 365 | void update_srcptr(const mart_cursor& mc) { 366 | switch (mc.pptr.ntype) { 367 | case MART_NODE_4: { 368 | m_array_4.update_srcptr(mc); 369 | break; 370 | } 371 | case MART_NODE_16: { 372 | m_array_16.update_srcptr(mc); 373 | break; 374 | } 375 | case MART_NODE_64: { // instead of MART_NODE_48 376 | m_array_48.update_srcptr(mc); 377 | 
break; 378 | } 379 | case MART_NODE_256: { 380 | m_array_256.update_srcptr(mc); 381 | break; 382 | } 383 | default: { 384 | ABORT_IF(true); // should not come 385 | break; 386 | } 387 | } 388 | } 389 | 390 | void split_node(mart_cursor& mc, const int depth) { 391 | ABORT_IF_LE(depth, m_depth_beg); 392 | ABORT_IF_NE(mc.nptr.ntype, MART_LEAF); 393 | 394 | m_split_count += 1; 395 | 396 | const uint32_t lpos = mc.nptr.nid; 397 | ABORT_IF_LE(m_idseqs.size(), lpos); 398 | 399 | // update cost 400 | if (m_infacts_ptr) { 401 | const int _depth = depth - m_depth_beg - 1; 402 | const int g_size = m_idseqs.group_size(lpos); 403 | 404 | m_trie_cost -= g_size * m_outfacts_ptr[_depth]; 405 | m_trie_cost += (m_infacts_ptr[_depth] * m_in_weight) + (g_size * m_outfacts_ptr[_depth + 1]); 406 | } 407 | 408 | std::vector> idbufs(m_sigma); 409 | { 410 | const auto ids = m_idseqs.extract(lpos); 411 | for (uint32_t id : ids) { 412 | const vint_type* vcode = m_database->access(id); 413 | const int c = vcode_tools::get_int(vcode, depth, get_bits()); 414 | idbufs[c].push_back(id); 415 | } 416 | } 417 | 418 | m_edges.clear(); 419 | 420 | for (int c = 0; c < m_sigma; c++) { 421 | if (idbufs[c].empty()) { 422 | continue; 423 | } 424 | if (m_edges.empty()) { 425 | const mart_pointer new_lptr = mart_pointer{lpos, MART_LEAF}; 426 | m_edges.push_back(mart_edge{uint8_t(c), new_lptr}); 427 | m_idseqs.insert(lpos, idbufs[c]); 428 | } else { 429 | const mart_pointer new_lptr = mart_pointer{uint32_t(m_idseqs.size()), MART_LEAF}; 430 | m_edges.push_back(mart_edge{uint8_t(c), new_lptr}); 431 | m_idseqs.push_back(idbufs[c]); 432 | } 433 | } 434 | 435 | if (m_edges.size() <= 4) { 436 | mc.nptr = m_array_4.make_node(m_edges); 437 | } else if (m_edges.size() <= 16) { 438 | mc.nptr = m_array_16.make_node(m_edges); 439 | } else if (m_edges.size() <= 48) { 440 | mc.nptr = m_array_48.make_node(m_edges); 441 | } else { 442 | mc.nptr = m_array_256.make_node(m_edges); 443 | } 444 | 445 | update_srcptr(mc); 446 | } 
447 | }; 448 | 449 | } // namespace dyft 450 | -------------------------------------------------------------------------------- /include/bit_tools.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #ifdef __SSE4_2__ 7 | #include 8 | #endif 9 | 10 | namespace dyft { 11 | 12 | struct bit_tools { 13 | static const uint8_t POPCNT_TABLE[256]; 14 | 15 | static int popcnt(uint8_t x) { 16 | return POPCNT_TABLE[x]; 17 | } 18 | static int popcnt(uint16_t x) { 19 | return POPCNT_TABLE[x & UINT8_MAX] + POPCNT_TABLE[x >> 8]; 20 | } 21 | static int popcnt(uint32_t x) { 22 | #ifdef __SSE4_2__ 23 | return static_cast(__builtin_popcount(x)); 24 | #else 25 | x = x - ((x >> 1) & 0x55555555); 26 | x = (x & 0x33333333) + ((x >> 2) & 0x33333333); 27 | return (0x10101010 * x >> 28) + (0x01010101 * x >> 28); 28 | #endif 29 | } 30 | static int popcnt(uint64_t x) { 31 | #ifdef __SSE4_2__ 32 | return static_cast(__builtin_popcountll(x)); 33 | #else 34 | x = x - ((x >> 1) & 0x5555555555555555ull); 35 | x = (x & 0x3333333333333333ull) + ((x >> 2) & 0x3333333333333333ull); 36 | x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0full; 37 | return (0x0101010101010101ull * x >> 56); 38 | #endif 39 | } 40 | 41 | static void set_bit(uint8_t& x, int i, bool bit = true) { 42 | assert(0 <= i and i <= 8); 43 | if (bit) { 44 | x |= (1U << i); 45 | } else { 46 | x &= ~(1U << i); 47 | } 48 | } 49 | static void set_bit(uint16_t& x, int i, bool bit = true) { 50 | assert(0 <= i and i <= 16); 51 | if (bit) { 52 | x |= (1U << i); 53 | } else { 54 | x &= ~(1U << i); 55 | } 56 | } 57 | static void set_bit(uint32_t& x, int i, bool bit = true) { 58 | assert(0 <= i and i <= 32); 59 | if (bit) { 60 | x |= (1U << i); 61 | } else { 62 | x &= ~(1U << i); 63 | } 64 | } 65 | static void set_bit(uint64_t& x, int i, bool bit = true) { 66 | assert(0 <= i and i <= 64); 67 | if (bit) { 68 | x |= (1ULL << i); 69 | } else { 70 | x &= ~(1ULL << i); 71 | 
} 72 | } 73 | 74 | static bool get_bit(uint8_t x, int i) { 75 | assert(0 <= i and i <= 8); 76 | return (x & (1U << i)) != 0; 77 | } 78 | static bool get_bit(uint16_t x, int i) { 79 | assert(0 <= i and i <= 16); 80 | return (x & (1U << i)) != 0; 81 | } 82 | static bool get_bit(uint32_t x, int i) { 83 | assert(0 <= i and i <= 32); 84 | return (x & (1U << i)) != 0; 85 | } 86 | static bool get_bit(uint64_t x, int i) { 87 | assert(0 <= i and i <= 64); 88 | return (x & (1ULL << i)) != 0; 89 | } 90 | }; 91 | 92 | const uint8_t bit_tools::POPCNT_TABLE[256] = { 93 | 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 94 | 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 95 | 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 96 | 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 97 | 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 98 | 5, 5, 6, 5, 6, 6, 7, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 99 | 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8}; 100 | 101 | } // namespace dyft 102 | -------------------------------------------------------------------------------- /include/cmd_line_parser/parser.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace cmd_line_parser { 12 | 13 | struct parser { 14 | inline static const std::string empty = ""; 15 | 16 | parser(int argc, char** argv) : m_argc(argc), m_argv(argv), m_required(0) {} 17 | 18 | struct cmd { 19 | std::string shorthand, value, descr; 20 | bool is_boolean; 
21 | }; 22 | 23 | bool parse() { 24 | if (size_t(m_argc - 1) < m_required) return abort(); 25 | size_t k = 0; 26 | for (int i = 1; i != m_argc; ++i, ++k) { 27 | std::string parsed(m_argv[i]); 28 | if (parsed == "-h" or parsed == "--help") return abort(); 29 | size_t id = k; 30 | bool is_optional = id >= m_required; 31 | if (is_optional) { 32 | auto it = m_shorthands.find(parsed); 33 | if (it == m_shorthands.end()) { 34 | std::cerr << "== error: shorthand '" + parsed + "' not found" << std::endl; 35 | return abort(); 36 | } 37 | id = (*it).second; 38 | } 39 | assert(id < m_names.size()); 40 | auto const& name = m_names[id]; 41 | auto& c = m_cmds[name]; 42 | if (is_optional) { 43 | if (c.is_boolean) { 44 | parsed = "true"; 45 | } else { 46 | ++i; 47 | if (i == m_argc) return abort(); 48 | parsed = m_argv[i]; 49 | } 50 | } 51 | c.value = parsed; 52 | } 53 | return true; 54 | } 55 | 56 | void help() const { 57 | std::cerr << "Usage: \e[1m" << m_argv[0] << "\e[0m [-h,--help]"; 58 | auto print = [this](bool with_description) { 59 | for (size_t i = 0; i != m_names.size(); ++i) { 60 | auto const& c = m_cmds.at(m_names[i]); 61 | bool is_optional = i >= m_required; 62 | if (is_optional) std::cerr << " [\e[1m" << c.shorthand << "\e[0m"; 63 | if (!c.is_boolean) std::cerr << " \e[4m" << m_names[i] << "\e[0m"; 64 | if (is_optional) std::cerr << "]"; 65 | if (with_description) std::cerr << "\n\t" << c.descr << "\n"; 66 | } 67 | }; 68 | print(false); 69 | std::cerr << "\n\n"; 70 | print(true); 71 | std::cerr << " [-h,--help]\n\tPrint this help text and silently exits." 
<< std::endl; 72 | } 73 | 74 | bool add(std::string const& name, std::string const& descr) { 75 | bool ret = m_cmds.emplace(name, cmd{empty, empty, descr, false}).second; 76 | if (ret) { 77 | m_names.push_back(name); 78 | m_required += 1; 79 | } 80 | return ret; 81 | } 82 | 83 | bool add(std::string const& name, std::string const& descr, std::string const& shorthand, bool is_boolean = true) { 84 | bool ret = m_cmds.emplace(name, cmd{shorthand, is_boolean ? "false" : empty, descr, is_boolean}).second; 85 | if (ret) { 86 | m_names.push_back(name); 87 | m_shorthands.emplace(shorthand, m_names.size() - 1); 88 | } 89 | return ret; 90 | } 91 | 92 | template 93 | T get(std::string const& name) const { 94 | auto it = m_cmds.find(name); 95 | if (it == m_cmds.end()) { 96 | throw std::runtime_error("error: '" + name + "' not found"); 97 | } 98 | auto const& value = (*it).second.value; 99 | return parse(value); 100 | } 101 | 102 | // added by Kampersanda 103 | template 104 | T get(std::string const& name, const T& default_value) const { 105 | return parsed(name) ? 
get(name) : default_value; 106 | } 107 | 108 | bool parsed(std::string const& name) const { 109 | auto it = m_cmds.find(name); 110 | if (it == m_cmds.end() or (*it).second.value == empty) return false; 111 | return true; 112 | } 113 | 114 | template 115 | T parse(std::string const& value) const { 116 | if constexpr (std::is_same::value) { 117 | return value; 118 | } else if constexpr (std::is_same::value or std::is_same::value or 119 | std::is_same::value) { 120 | return value.front(); 121 | } else if constexpr (std::is_same::value or std::is_same::value or 122 | std::is_same::value or std::is_same::value) { 123 | return std::atoi(value.c_str()); 124 | } else if constexpr (std::is_same::value or std::is_same::value or 125 | std::is_same::value or std::is_same::value) { 126 | return std::atoll(value.c_str()); 127 | } else if constexpr (std::is_same::value or std::is_same::value or 128 | std::is_same::value) { 129 | return std::atof(value.c_str()); 130 | } else if constexpr (std::is_same::value) { 131 | std::istringstream stream(value); 132 | bool ret; 133 | if (value == "true" or value == "false") { 134 | stream >> std::boolalpha >> ret; 135 | } else { 136 | stream >> std::noboolalpha >> ret; 137 | } 138 | return ret; 139 | } 140 | assert(false); 141 | __builtin_unreachable(); 142 | } 143 | 144 | private: 145 | int m_argc; 146 | char** m_argv; 147 | size_t m_required; 148 | std::unordered_map m_cmds; 149 | std::unordered_map m_shorthands; 150 | std::vector m_names; 151 | 152 | bool abort() const { 153 | help(); 154 | return false; 155 | } 156 | }; 157 | 158 | } // namespace cmd_line_parser -------------------------------------------------------------------------------- /include/dyft_interface.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "vcode_array.hpp" 4 | 5 | namespace dyft { 6 | 7 | template 8 | class dyft_interface { 9 | public: 10 | using vint_type = typename vcode_tools::vint_type; 11 
| 12 | virtual ~dyft_interface() {} 13 | 14 | virtual uint32_t append() = 0; 15 | virtual void range_search(const vint_type* vcode, const std::function& fn) = 0; 16 | 17 | virtual uint32_t get_size() const = 0; 18 | virtual uint32_t get_leaves() const = 0; 19 | virtual int get_bits() const = 0; 20 | 21 | virtual bool selects_ls() const = 0; 22 | virtual double get_ls_cost() const = 0; 23 | virtual double get_trie_cost() const = 0; 24 | 25 | virtual size_t get_split_count() = 0; 26 | 27 | virtual void innode_stats() = 0; 28 | virtual void population_stats() = 0; 29 | }; 30 | 31 | } // namespace dyft 32 | -------------------------------------------------------------------------------- /include/gv_index.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "vcode_array.hpp" 8 | 9 | namespace dyft { 10 | 11 | // A straightforward implementation of multi-index hashing for integer sketches, following the idea 12 | // - Gog and Venturini. 
Fast and compact Hamming distance index, SIGIR16 13 | template 14 | class gv_index { 15 | public: 16 | using vint_type = typename vcode_tools::vint_type; 17 | using table_type = std::unordered_map>; 18 | 19 | private: 20 | const vcode_array* m_database = nullptr; 21 | const int m_sigma = 0; 22 | const int m_radius = 0; 23 | const int m_blocks = 0; 24 | 25 | std::vector m_tables; 26 | std::vector m_begs; 27 | uint32_t m_ids = 0; 28 | 29 | public: 30 | gv_index(const vcode_array* database, int radius) 31 | : m_database(database), m_sigma(1 << m_database->get_bits()), m_radius(radius), m_blocks(radius / 2 + 1), 32 | m_tables(m_blocks), m_begs(m_blocks + 1) { 33 | m_begs[0] = 0; 34 | for (int b = 0; b < m_blocks; b++) { 35 | const int m = (b + N) / m_blocks; 36 | m_begs[b + 1] = m_begs[b] + m; 37 | } 38 | ABORT_IF_NE(m_begs[m_blocks], N); 39 | } 40 | 41 | uint32_t append() { 42 | const uint32_t new_id = m_ids++; 43 | const vint_type* vcode = m_database->access(new_id); 44 | 45 | int code[N]; 46 | for (int i = 0; i < N; i++) { 47 | code[i] = vcode_tools::get_int(vcode, i, get_bits()); 48 | } 49 | for (int b = 0; b < m_blocks; b++) { 50 | insert(code, b, new_id); 51 | } 52 | 53 | return new_id; 54 | } 55 | 56 | void range_search(const vint_type* vcode, const std::function& fn) { 57 | std::vector cands; 58 | cands.reserve(1U << 10); 59 | 60 | int code[N]; 61 | for (int i = 0; i < N; i++) { 62 | code[i] = vcode_tools::get_int(vcode, i, get_bits()); 63 | } 64 | for (int b = 0; b < m_blocks; b++) { 65 | search(code, b, [&](uint32_t id) { cands.push_back(id); }); 66 | } 67 | 68 | std::sort(cands.begin(), cands.end()); 69 | cands.erase(std::unique(cands.begin(), cands.end()), cands.end()); 70 | 71 | for (uint32_t id : cands) { 72 | fn(id); 73 | } 74 | } 75 | 76 | uint32_t get_size() const { 77 | return m_ids; 78 | } 79 | int get_bits() const { 80 | return m_database->get_bits(); 81 | } 82 | 83 | private: 84 | static uint64_t fnv1a(const int* str, const int n) { 85 | static 
const uint64_t init = uint64_t((sizeof(uint64_t) == 8) ? 0xcbf29ce484222325ULL : 0x811c9dc5ULL); 86 | static const uint64_t multiplier = uint64_t((sizeof(uint64_t) == 8) ? 0x100000001b3ULL : 0x1000193ULL); 87 | 88 | uint64_t h = init; 89 | for (int i = 0; i < n; i++) { 90 | h ^= uint64_t(str[i]); 91 | h *= multiplier; 92 | } 93 | return h; 94 | } 95 | 96 | void insert(const int* code, const int bpos, const uint32_t new_id) { 97 | const int beg = m_begs[bpos]; 98 | const int len = m_begs[bpos + 1] - m_begs[bpos]; 99 | 100 | code += beg; 101 | auto& table = m_tables[bpos]; 102 | 103 | const uint64_t h = fnv1a(code, len); 104 | 105 | auto it = table.find(h); 106 | if (it != table.end()) { // found? 107 | it->second.push_back(new_id); 108 | } else { 109 | table.insert(std::make_pair(h, std::vector{new_id})); 110 | } 111 | } 112 | 113 | void search(const int* code, const int bpos, const std::function& fn) { 114 | const int beg = m_begs[bpos]; 115 | const int len = m_begs[bpos + 1] - m_begs[bpos]; 116 | 117 | code += beg; 118 | const auto& table = m_tables[bpos]; 119 | 120 | // Exact Search 121 | { 122 | const uint64_t h = fnv1a(code, len); 123 | auto it = table.find(h); 124 | if (it != table.end()) { // found? 125 | for (uint32_t id : it->second) { 126 | fn(id); 127 | } 128 | } 129 | } 130 | 131 | // 1-err searches 132 | int sig[N]; 133 | for (int i = 0; i < len; i++) { 134 | std::copy(code, code + len, sig); 135 | for (int j = 1; j < m_sigma; j++) { 136 | sig[i] = (sig[i] + 1) % m_sigma; 137 | const uint64_t h = fnv1a(sig, len); 138 | auto it = table.find(h); 139 | if (it != table.end()) { // found? 
140 | for (uint32_t id : it->second) { 141 | fn(id); 142 | } 143 | } 144 | } 145 | } 146 | } 147 | }; 148 | 149 | } // namespace dyft 150 | -------------------------------------------------------------------------------- /include/hms1dv_index.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "vcode_array.hpp" 8 | 9 | namespace dyft { 10 | 11 | // A straightforward implementation of dynamic HmSearch 1-del-var (HSD), described in the paper 12 | // - Zhang et al. HmSearch: An efficient Hamming distance query processing algorithm, SSDBM2013 13 | template 14 | class hms1dv_index { 15 | public: 16 | using vint_type = typename vcode_tools::vint_type; 17 | using table_type = std::unordered_map>; 18 | 19 | private: 20 | const vcode_array* m_database = nullptr; 21 | const int m_sigma = 0; // deletion marker also 22 | const int m_radius = 0; 23 | const int m_blocks = 0; 24 | 25 | std::vector m_tables; 26 | std::vector m_begs; 27 | uint32_t m_ids = 0; 28 | 29 | std::unordered_map m_match_map; 30 | std::unordered_map> m_cand_map; 31 | 32 | public: 33 | hms1dv_index(const vcode_array* database, int radius) 34 | : m_database(database), m_sigma(1 << m_database->get_bits()), m_radius(radius), m_blocks((radius + 3) / 2), 35 | m_tables(m_blocks), m_begs(m_blocks + 1) { 36 | m_begs[0] = 0; 37 | for (int b = 0; b < m_blocks; b++) { 38 | const int m = (b + N) / m_blocks; 39 | m_begs[b + 1] = m_begs[b] + m; 40 | } 41 | ABORT_IF_NE(m_begs[m_blocks], N); 42 | } 43 | 44 | uint32_t append() { 45 | const uint32_t new_id = m_ids++; 46 | const vint_type* vcode = m_database->access(new_id); 47 | 48 | int code[N]; 49 | for (int i = 0; i < N; i++) { 50 | code[i] = vcode_tools::get_int(vcode, i, get_bits()); 51 | } 52 | for (int b = 0; b < m_blocks; b++) { 53 | insert(code, b, new_id); 54 | } 55 | 56 | return new_id; 57 | } 58 | 59 | void range_search(const vint_type* vcode, const 
std::function& fn) { 60 | int code[N]; 61 | for (int i = 0; i < N; i++) { 62 | code[i] = vcode_tools::get_int(vcode, i, get_bits()); 63 | } 64 | 65 | m_cand_map.clear(); 66 | 67 | for (int b = 0; b < m_blocks; b++) { 68 | m_match_map.clear(); 69 | 70 | search(code, b, [&](uint32_t id) { 71 | auto it = m_match_map.find(id); 72 | if (it == m_match_map.end()) { 73 | m_match_map.insert(std::make_pair(id, 1U)); 74 | } else { 75 | it->second += 1; 76 | } 77 | }); 78 | 79 | for (const auto& kv : m_match_map) { 80 | if (kv.second > 2) { 81 | auto it = m_cand_map.find(kv.first); 82 | if (it != m_cand_map.end()) { 83 | it->second.push_back(0); 84 | } else { 85 | m_cand_map.insert(std::make_pair(kv.first, std::vector{0U})); 86 | } 87 | } else { 88 | auto it = m_cand_map.find(kv.first); 89 | if (it != m_cand_map.end()) { 90 | it->second.push_back(1); 91 | } else { 92 | m_cand_map.insert(std::make_pair(kv.first, std::vector{1U})); 93 | } 94 | } 95 | } 96 | } 97 | 98 | for (const auto& kv : m_cand_map) { 99 | uint32_t cand_id = kv.first; 100 | const std::vector& errors = kv.second; 101 | 102 | ABORT_IF_EQ(errors.size(), 0); 103 | 104 | // enhanced filter 105 | bool filtered = false; 106 | 107 | if (m_radius % 2 == 0) { 108 | if (errors.size() < 2) { // has less than two number 109 | if (errors[0] == 1) { 110 | filtered = true; 111 | } 112 | } 113 | } else { 114 | if (errors.size() < 3) { // has less than three number 115 | if (errors.size() == 1) { 116 | filtered = true; 117 | } else if (errors[0] == 1 and errors[1] == 1) { 118 | filtered = true; 119 | } 120 | } 121 | } 122 | 123 | if (!filtered) { 124 | fn(cand_id); 125 | } 126 | } 127 | } 128 | 129 | uint32_t get_size() const { 130 | return m_ids; 131 | } 132 | int get_bits() const { 133 | return m_database->get_bits(); 134 | } 135 | 136 | private: 137 | static uint64_t fnv1a(const int* str, const int n) { 138 | static const uint64_t init = uint64_t((sizeof(uint64_t) == 8) ? 
0xcbf29ce484222325ULL : 0x811c9dc5ULL); 139 | static const uint64_t multiplier = uint64_t((sizeof(uint64_t) == 8) ? 0x100000001b3ULL : 0x1000193ULL); 140 | 141 | uint64_t h = init; 142 | for (int i = 0; i < n; i++) { 143 | h ^= uint64_t(str[i]); 144 | h *= multiplier; 145 | } 146 | return h; 147 | } 148 | 149 | void insert(const int* code, const int bpos, const uint32_t new_id) { // indexes new_id under every copy of block bpos that has one position replaced by m_sigma 150 | const int beg = m_begs[bpos]; 151 | const int len = m_begs[bpos + 1] - m_begs[bpos]; 152 | 153 | code += beg; 154 | auto& table = m_tables[bpos]; 155 | 156 | int sig[N]; 157 | for (int i = 0; i < len; i++) { 158 | std::copy(code, code + len, sig); 159 | sig[i] = m_sigma; 160 | 161 | const uint64_t h = fnv1a(sig, len); 162 | 163 | auto it = table.find(h); 164 | if (it != table.end()) { // found? 165 | it->second.push_back(new_id); 166 | } else { 167 | table.insert(std::make_pair(h, std::vector{new_id})); 168 | } 169 | } 170 | } 171 | 172 | void search(const int* code, const int bpos, const std::function& fn) { // probes the same masked signatures for the query block; an id can be reported more than once 173 | const int beg = m_begs[bpos]; 174 | const int len = m_begs[bpos + 1] - m_begs[bpos]; 175 | 176 | code += beg; 177 | const auto& table = m_tables[bpos]; 178 | 179 | int sig[N]; 180 | for (int i = 0; i < len; i++) { 181 | std::copy(code, code + len, sig); 182 | sig[i] = m_sigma; 183 | 184 | const uint64_t h = fnv1a(sig, len); 185 | 186 | auto it = table.find(h); 187 | if (it != table.end()) { // found? 188 | for (uint32_t id : it->second) { 189 | fn(id); 190 | } 191 | } 192 | } 193 | } 194 | }; 195 | 196 | } // namespace dyft 197 | -------------------------------------------------------------------------------- /include/hms1v_index.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "vcode_array.hpp" 8 | 9 | namespace dyft { 10 | 11 | // A straightforward implementation of dynamic HmSearch 1-var (HSV), described in the paper 12 | // - Zhang et al.
HmSearch: An efficient Hamming distance query processing algorithm, SSDBM2013 13 | template 14 | class hms1v_index { 15 | public: 16 | struct idvec_pair { // posting lists stored under one block-signature hash 17 | std::vector exact_vec; // ids whose block itself has this hash 18 | std::vector variant_vec; // ids with a one-symbol substitution of the block that has this hash 19 | }; 20 | using vint_type = typename vcode_tools::vint_type; 21 | using table_type = std::unordered_map; 22 | 23 | private: 24 | const vcode_array* m_database = nullptr; 25 | const int m_sigma = 0; // 1 << bits; insert() substitutes every symbol in [0, m_sigma) 26 | const int m_radius = 0; 27 | const int m_blocks = 0; 28 | 29 | std::vector m_tables; 30 | std::vector m_begs; 31 | uint32_t m_ids = 0; 32 | 33 | std::unordered_map> m_cand_map; 34 | 35 | public: 36 | hms1v_index(const vcode_array* database, int radius) 37 | : m_database(database), m_sigma(1 << m_database->get_bits()), m_radius(radius), m_blocks((radius + 3) / 2), 38 | m_tables(m_blocks), m_begs(m_blocks + 1) { 39 | m_begs[0] = 0; 40 | for (int b = 0; b < m_blocks; b++) { 41 | const int m = (b + N) / m_blocks; // near-equal block lengths summing to N (checked below) 42 | m_begs[b + 1] = m_begs[b] + m; 43 | } 44 | ABORT_IF_NE(m_begs[m_blocks], N); 45 | } 46 | 47 | uint32_t append() { 48 | const uint32_t new_id = m_ids++; 49 | const vint_type* vcode = m_database->access(new_id); 50 | 51 | int code[N]; 52 | for (int i = 0; i < N; i++) { 53 | code[i] = vcode_tools::get_int(vcode, i, get_bits()); 54 | } 55 | for (int b = 0; b < m_blocks; b++) { 56 | insert(code, b, new_id); 57 | } 58 | 59 | return new_id; 60 | } 61 | 62 | void range_search(const vint_type* vcode, const std::function& fn) { // gathers per-block match errors (0 = exact, 1 = one substitution) for each candidate, then filters 63 | int code[N]; 64 | for (int i = 0; i < N; i++) { 65 | code[i] = vcode_tools::get_int(vcode, i, get_bits()); 66 | } 67 | 68 | m_cand_map.clear(); 69 | 70 | for (int b = 0; b < m_blocks; b++) { 71 | search(code, b, [&](uint32_t id, uint32_t errs) { 72 | auto it = m_cand_map.find(id); 73 | if (it != m_cand_map.end()) { 74 | it->second.push_back(errs); 75 | } else { 76 | m_cand_map.insert(std::make_pair(id, std::vector{errs})); 77 | } 78 | }); 79 | } 80 | 81 | for (const auto& kv : m_cand_map) { 82 |
uint32_t cand_id = kv.first; 83 | const std::vector& errors = kv.second; 84 | 85 | ABORT_IF_EQ(errors.size(), 0); 86 | 87 | // enhanced filter: drop candidates whose block matches are too few or too inexact for the radius parity 88 | bool filtered = false; 89 | 90 | if (m_radius % 2 == 0) { 91 | if (errors.size() < 2) { // only one block matched 92 | if (errors[0] == 1) { 93 | filtered = true; 94 | } 95 | } 96 | } else { 97 | if (errors.size() < 3) { // only one or two blocks matched 98 | if (errors.size() == 1) { 99 | filtered = true; 100 | } else if (errors[0] == 1 and errors[1] == 1) { 101 | filtered = true; 102 | } 103 | } 104 | } 105 | 106 | if (!filtered) { 107 | fn(cand_id); 108 | } 109 | } 110 | } 111 | 112 | uint32_t get_size() const { 113 | return m_ids; 114 | } 115 | int get_bits() const { 116 | return m_database->get_bits(); 117 | } 118 | 119 | private: 120 | static uint64_t fnv1a(const int* str, const int n) { // FNV-1a over n ints; the sizeof test always picks the 64-bit constants 121 | static const uint64_t init = uint64_t((sizeof(uint64_t) == 8) ? 0xcbf29ce484222325ULL : 0x811c9dc5ULL); 122 | static const uint64_t multiplier = uint64_t((sizeof(uint64_t) == 8) ? 0x100000001b3ULL : 0x1000193ULL); 123 | 124 | uint64_t h = init; 125 | for (int i = 0; i < n; i++) { 126 | h ^= uint64_t(str[i]); 127 | h *= multiplier; 128 | } 129 | return h; 130 | } 131 | 132 | void insert(const int* code, const int bpos, const uint32_t new_id) { // indexes new_id under block bpos exactly and under every 1-substitution variant 133 | const int beg = m_begs[bpos]; 134 | const int len = m_begs[bpos + 1] - m_begs[bpos]; 135 | 136 | code += beg; 137 | auto& table = m_tables[bpos]; 138 | 139 | // exact signature 140 | { 141 | const uint64_t h = fnv1a(code, len); 142 | 143 | auto it = table.find(h); 144 | if (it != table.end()) { // found?
145 | it->second.exact_vec.push_back(new_id); 146 | } else { 147 | table.insert(std::make_pair(h, idvec_pair{std::vector{new_id}, std::vector{}})); 148 | } 149 | } 150 | 151 | // 1-variant signatures: each position i replaced by every other symbol 152 | int sig[N]; 153 | for (int i = 0; i < len; i++) { 154 | std::copy(code, code + len, sig); 155 | 156 | for (int c = 0; c < m_sigma; c++) { 157 | if (c == code[i]) { 158 | continue; 159 | } 160 | 161 | sig[i] = c; 162 | const uint64_t h = fnv1a(sig, len); 163 | 164 | auto it = table.find(h); 165 | if (it != table.end()) { // found? 166 | it->second.variant_vec.push_back(new_id); 167 | } else { 168 | table.insert(std::make_pair(h, idvec_pair{std::vector{}, std::vector{new_id}})); 169 | } 170 | } 171 | } 172 | } 173 | 174 | void search(const int* code, const int bpos, const std::function& fn) { // reports exact matches with errs 0 and 1-variant matches with errs 1 175 | const int beg = m_begs[bpos]; 176 | const int len = m_begs[bpos + 1] - m_begs[bpos]; 177 | 178 | code += beg; 179 | const auto& table = m_tables[bpos]; 180 | 181 | const uint64_t h = fnv1a(code, len); 182 | 183 | auto it = table.find(h); 184 | if (it == table.end()) { // not found?
185 | return; 186 | } 187 | 188 | for (uint32_t id : it->second.exact_vec) { 189 | fn(id, 0); 190 | } 191 | for (uint32_t id : it->second.variant_vec) { 192 | fn(id, 1); 193 | } 194 | } 195 | }; 196 | 197 | } // namespace dyft 198 | -------------------------------------------------------------------------------- /include/io.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include // is_any_of 7 | #include 8 | #include 9 | 10 | #include "tinyformat/tinyformat.h" 11 | 12 | #include "abort_if.hpp" 13 | #include "splitmix.hpp" 14 | #include "vcode_array.hpp" 15 | 16 | namespace dyft { 17 | 18 | inline uint64_t get_filesize(const std::string& filepath) { 19 | boost::filesystem::path path(filepath); 20 | return boost::filesystem::file_size(path); 21 | } 22 | 23 | inline std::ifstream make_ifstream(const std::string& filepath, std::ios_base::openmode mode = std::ios_base::in) { // aborts if the file cannot be opened; pass std::ios::binary for raw data 24 | std::ifstream ifs(filepath, mode); 25 | ABORT_IF(!ifs); 26 | return ifs; 27 | } 28 | inline std::ofstream make_ofstream(const std::string& filepath, std::ios_base::openmode mode = std::ios_base::out) { // aborts if the file cannot be opened 29 | std::ofstream ofs(filepath, mode); 30 | ABORT_IF(!ofs); 31 | return ofs; 32 | } 33 | 34 | inline bool exists_path(const std::string path) { 35 | boost::filesystem::path p(path); 36 | return boost::filesystem::exists(p); 37 | } 38 | 39 | inline void make_directory(const std::string dir) { 40 | boost::filesystem::path path(dir); 41 | if (!boost::filesystem::exists(path)) { 42 | if (!boost::filesystem::create_directories(path)) { 43 | tfm::errorfln("unable to create output directory %s", dir); 44 | exit(1); 45 | } 46 | } 47 | } 48 | 49 | template 50 | std::unique_ptr> load_vcodes_from_bin(const std::string& path) { // the file is a flat array of uint64_t codes, read in binary mode 51 | using vint_type = typename vcode_tools::vint_type; 52 | 53 | const uint64_t num_codes = get_filesize(path) / sizeof(uint64_t); 54 | std::vector vcodes(num_codes); 55 | 56 | auto ifs = make_ifstream(path, std::ios::binary); 57 | for (uint64_t i = 0; i < num_codes; i++) { 58 | uint64_t vc = 0; 59 |
ifs.read(reinterpret_cast(&vc), sizeof(uint64_t)); 60 | vcodes[i] = static_cast(vc); 61 | } 62 | return std::make_unique>(std::move(vcodes), 1); 63 | } 64 | 65 | template 66 | std::unique_ptr> load_vcodes_from_bvecs(const std::string& path, int bits) { // records: uint32 dim followed by dim bytes (dim must be >= N) 67 | auto database = std::make_unique>(bits); 68 | 69 | std::vector code; 70 | for (auto ifs = make_ifstream(path, std::ios::binary);;) { 71 | uint32_t dim = 0; 72 | ifs.read(reinterpret_cast(&dim), sizeof(uint32_t)); 73 | if (ifs.eof()) { 74 | break; 75 | } 76 | ABORT_IF_LT(dim, N); 77 | 78 | code.resize(dim); 79 | ifs.read(reinterpret_cast(code.data()), sizeof(uint8_t) * dim); ABORT_IF(!ifs); // reject a truncated record instead of appending partial data 80 | 81 | database->append(code.data()); 82 | } 83 | return database; 84 | } 85 | 86 | } // namespace dyft 87 | -------------------------------------------------------------------------------- /include/mart_array_dense.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #if defined(__AVX__) || defined(__AVX2__) 4 | #include 5 | #endif 6 | 7 | #include 8 | #include 9 | 10 | #include "abort_if.hpp" 11 | #include "mart_common.hpp" 12 | 13 | namespace dyft { 14 | 15 | template 16 | class mart_array_dense { 17 | static_assert(2 <= K and K < 255); 18 | 19 | public: 20 | static constexpr uint64_t PTR_SIZE = 5; 21 | static constexpr uint64_t IDXS_OFFSET = 1; 22 | static constexpr uint64_t PTRS_OFFSET = 1 + 256; 23 | 24 | static constexpr uint64_t BYTES = 1 + 256 + (K * PTR_SIZE); // header + 256 indexes + K ptrs 25 | static constexpr mart_node_types NTYPE = NodeType; 26 | 27 | static constexpr uint8_t NIL_IDX = UINT8_MAX; 28 | 29 | private: 30 | std::vector m_nodes; 31 | uint32_t m_head_nid = MART_NILID; 32 | uint32_t m_num_emps = 0; 33 | 34 | static mart_pointer get_martptr(const uint8_t* ptrs) { 35 | return {*reinterpret_cast(ptrs), mart_node_types(ptrs[4])}; 36 | } 37 | static void set_martptr(uint8_t* ptrs, mart_pointer new_ptr) { 38 | *reinterpret_cast(ptrs) = new_ptr.nid; 39 | ptrs[4] =
uint8_t(new_ptr.ntype); 40 | } 41 | 42 | public: 43 | mart_array_dense() = default; 44 | 45 | mart_insert_flags insert_ptr(mart_cursor& mc, uint8_t c, mart_pointer new_ptr) { // finds child c or stores new_ptr in a free slot; on MART_FOUND/MART_INSERTED mc is moved to that child 46 | DEBUG_ABORT_IF_NE(mc.nptr.ntype, NTYPE); 47 | 48 | const uint64_t pos = mc.nptr.nid * BYTES; 49 | 50 | const uint32_t num = m_nodes[pos]; 51 | uint8_t* idxs = &m_nodes[pos + IDXS_OFFSET]; 52 | uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 53 | 54 | DEBUG_ABORT_IF_LT(K, num); 55 | 56 | if (idxs[c] != NIL_IDX) { // found? 57 | const uint32_t i = idxs[c]; 58 | mc = mart_cursor{i, mc.nptr, get_martptr(ptrs + (i * PTR_SIZE))}; 59 | return MART_FOUND; 60 | } 61 | 62 | // can insert 63 | if (num < K) { 64 | m_nodes[pos] += 1; 65 | idxs[c] = static_cast(num); 66 | set_martptr(ptrs + (num * PTR_SIZE), new_ptr); 67 | 68 | mc = mart_cursor{num, mc.nptr, new_ptr}; 69 | return MART_INSERTED; // inserted 70 | } 71 | 72 | return MART_NEEDED_TO_EXPAND; 73 | } 74 | 75 | mart_insert_flags append_ptr(mart_cursor& mc, uint8_t c, mart_pointer new_ptr) { // insert without lookup: c must be absent and a slot free (debug-checked) 76 | DEBUG_ABORT_IF_NE(mc.nptr.ntype, NTYPE); 77 | 78 | const uint64_t pos = mc.nptr.nid * BYTES; 79 | 80 | const uint32_t num = m_nodes[pos]; 81 | uint8_t* idxs = &m_nodes[pos + IDXS_OFFSET]; 82 | uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 83 | 84 | DEBUG_ABORT_IF_LE(K, num); 85 | DEBUG_ABORT_IF_NE(idxs[c], NIL_IDX); 86 | 87 | // not searched 88 | m_nodes[pos] += 1; 89 | idxs[c] = static_cast(num); 90 | set_martptr(ptrs + (num * PTR_SIZE), new_ptr); 91 | 92 | // set src 93 | mc = mart_cursor{num, mc.nptr, new_ptr}; 94 | return MART_INSERTED; 95 | } 96 | 97 | void extract(const mart_cursor& mc, std::vector& edges) { // appends every edge to `edges`, then returns this node to the free list 98 | DEBUG_ABORT_IF_NE(mc.nptr.ntype, NTYPE); 99 | 100 | const uint32_t nid = mc.nptr.nid; 101 | const uint64_t pos = nid * BYTES; 102 | 103 | const uint8_t* idxs = &m_nodes[pos + IDXS_OFFSET]; 104 | const uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 105 | 106 | for (uint32_t i = 0; i < 256; i++) { 107 | if (idxs[i] != NIL_IDX) { 108 |
edges.push_back(mart_edge{uint8_t(i), get_martptr(ptrs + (idxs[i] * PTR_SIZE))}); 109 | } 110 | } 111 | 112 | // Push this node onto the circular doubly-linked free list (links overwrite its first 8 bytes) 113 | if (m_head_nid == MART_NILID) { 114 | prev_ref(nid) = nid; 115 | next_ref(nid) = nid; 116 | m_head_nid = nid; 117 | } else { 118 | const uint32_t prev_nid = prev_ref(m_head_nid); 119 | prev_ref(nid) = prev_nid; 120 | next_ref(nid) = m_head_nid; 121 | next_ref(prev_nid) = nid; 122 | prev_ref(m_head_nid) = nid; 123 | } 124 | 125 | m_num_emps += 1; 126 | } 127 | 128 | mart_pointer make_node() { 129 | // Reuse a node from the free list, resetting its fanout and index table 130 | if (m_head_nid != MART_NILID) { 131 | m_num_emps -= 1; 132 | 133 | const uint32_t new_nid = m_head_nid; 134 | const uint32_t prev_nid = prev_ref(m_head_nid); 135 | const uint32_t next_nid = next_ref(m_head_nid); 136 | 137 | if (next_nid == m_head_nid) { 138 | m_head_nid = MART_NILID; 139 | } else { 140 | m_head_nid = next_nid; 141 | prev_ref(next_nid) = prev_nid; 142 | next_ref(prev_nid) = next_nid; 143 | } 144 | 145 | const uint64_t new_pos = new_nid * BYTES; 146 | 147 | m_nodes[new_pos] = 0; // fanout 148 | uint8_t* idxs = &m_nodes[new_pos + IDXS_OFFSET]; 149 | std::fill(idxs, idxs + 256, NIL_IDX); 150 | 151 | return mart_pointer{new_nid, NTYPE}; 152 | } 153 | 154 | if (m_nodes.size() == 0) { 155 | m_nodes.reserve(BYTES); 156 | } else if (m_nodes.size() + BYTES > m_nodes.capacity()) { 157 | m_nodes.reserve(m_nodes.capacity() * 2); 158 | } 159 | 160 | const uint64_t new_pos = m_nodes.size(); 161 | const uint32_t new_nid = static_cast(new_pos / BYTES); 162 | m_nodes.resize(new_pos + BYTES); 163 | m_nodes[new_pos] = 0; // fanout 164 | 165 | uint8_t* idxs = &m_nodes[new_pos + IDXS_OFFSET]; 166 | std::fill(idxs, idxs + 256, NIL_IDX); 167 | 168 | return mart_pointer{new_nid, NTYPE}; 169 | } 170 | 171 | // builds a node holding `edges`; returns its pointer 172 | mart_pointer make_node(const std::vector& edges) { 173 | ABORT_IF_OUT(edges.size(), 0, K); 174 | 175 | const mart_pointer new_nptr = make_node(); 176 | 177 | const uint32_t new_nid =
new_nptr.nid; 178 | const uint64_t new_pos = new_nid * BYTES; 179 | 180 | uint8_t* idxs = &m_nodes[new_pos + IDXS_OFFSET]; 181 | uint8_t* ptrs = &m_nodes[new_pos + PTRS_OFFSET]; 182 | 183 | m_nodes[new_pos] = static_cast(edges.size()); 184 | for (uint32_t i = 0; i < edges.size(); i++) { 185 | idxs[edges[i].c] = static_cast(i); 186 | set_martptr(ptrs + (i * PTR_SIZE), edges[i].nptr); 187 | } 188 | 189 | return new_nptr; 190 | } 191 | 192 | void update_srcptr(const mart_cursor& mc) { 193 | DEBUG_ABORT_IF_NE(mc.pptr.ntype, NTYPE); 194 | 195 | const uint64_t pos = mc.pptr.nid * BYTES; 196 | uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 197 | 198 | set_martptr(ptrs + (mc.offset * PTR_SIZE), mc.nptr); 199 | } 200 | 201 | mart_pointer find_child(mart_pointer nptr, uint8_t c) { // nil pointer when c is absent 202 | DEBUG_ABORT_IF_NE(nptr.ntype, NTYPE); 203 | 204 | const uint32_t nid = nptr.nid; 205 | const uint64_t pos = nid * BYTES; 206 | 207 | const uint8_t* idxs = &m_nodes[pos + IDXS_OFFSET]; 208 | const uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 209 | 210 | if (idxs[c] != NIL_IDX) { // found? 211 | return get_martptr(ptrs + (idxs[c] * PTR_SIZE)); 212 | } 213 | return make_mart_nullptr(); 214 | } 215 | 216 | struct searcher { 217 | searcher() = default; 218 | 219 | searcher(const mart_array_dense* obj, mart_pointer nptr) { 220 | set(obj, nptr); 221 | } 222 | 223 | void set(const mart_array_dense* obj, mart_pointer nptr) { 224 | DEBUG_ABORT_IF_NE(nptr.ntype, NTYPE); 225 | 226 | const uint64_t pos = nptr.nid * BYTES; 227 | 228 | m_num = obj->m_nodes[pos]; 229 | m_idxs = &obj->m_nodes[pos + IDXS_OFFSET]; 230 | m_ptrs = &obj->m_nodes[pos + PTRS_OFFSET]; 231 | } 232 | 233 | mart_pointer find(uint8_t c) { 234 | if (m_idxs[c] != NIL_IDX) { // found?
235 | return get_martptr(m_ptrs + (m_idxs[c] * PTR_SIZE)); 236 | } 237 | return make_mart_nullptr(); 238 | } 239 | void scan(std::vector& edges) { 240 | for (uint32_t i = 0; i < 256; i++) { 241 | if (m_idxs[i] != NIL_IDX) { 242 | edges.push_back(mart_edge{uint8_t(i), get_martptr(m_ptrs + (m_idxs[i] * PTR_SIZE))}); 243 | } 244 | } 245 | } 246 | 247 | private: 248 | uint32_t m_num = 0; 249 | const uint8_t* m_idxs = nullptr; 250 | const uint8_t* m_ptrs = nullptr; 251 | }; 252 | 253 | void population_stats() { // prints a fanout histogram over nodes not on the free list 254 | std::set emps = get_emps(); 255 | 256 | uint64_t cnts[K + 1]; 257 | std::fill(cnts, cnts + K + 1, 0); 258 | 259 | for (uint64_t i = 0; i < m_nodes.size(); i += BYTES) { 260 | if (emps.find(i) != emps.end()) { 261 | continue; 262 | } 263 | DEBUG_ABORT_IF_LT(K, m_nodes[i]); 264 | cnts[m_nodes[i]] += 1; 265 | } 266 | 267 | const double sum = std::accumulate(cnts, cnts + K + 1, 0.0); 268 | 269 | tfm::printfln("## mart_array_dense (K=%d) ##", K); 270 | for (uint32_t k = 0; k <= K; k++) { 271 | if (cnts[k] == 0) continue; 272 | tfm::printfln("- k=%d: %d (%g)", k, cnts[k], cnts[k] / sum); 273 | } 274 | } 275 | 276 | uint32_t num_nodes() { 277 | return (m_nodes.size() / BYTES) - m_num_emps; 278 | } 279 | uint32_t num_emp_nodes() { 280 | return m_num_emps; 281 | } 282 | 283 | void debug_emplist() { 284 | uint32_t emps = 0; 285 | if (m_head_nid != MART_NILID) { 286 | uint32_t emp_nid = m_head_nid; 287 | while (true) { 288 | emps += 1; 289 | uint32_t next_nid = next_ref(emp_nid); 290 | if (next_nid == m_head_nid) break; 291 | emp_nid = next_nid; 292 | } 293 | } 294 | ABORT_IF_NE(m_num_emps, emps); 295 | } 296 | 297 | std::set get_emps() { 298 | std::set emps; 299 | if (m_head_nid != MART_NILID) { 300 | uint32_t emp_nid = m_head_nid; 301 | while (true) { 302 | emps.insert(emp_nid * BYTES); 303 | uint32_t next_nid = next_ref(emp_nid); 304 | if (next_nid == m_head_nid) break; 305 | emp_nid = next_nid; 306 | } 307 | } 308 | return emps; 309 | } 310 | 311 | private: 312 |
uint32_t& prev_ref(uint32_t nid) { // NOTE(review): accesses a uint32_t at byte offset nid * BYTES, which may be misaligned; memcpy-based access would avoid UB 313 | return reinterpret_cast(&m_nodes[nid * BYTES])[0]; 314 | } 315 | uint32_t& next_ref(uint32_t nid) { 316 | return reinterpret_cast(&m_nodes[nid * BYTES])[1]; 317 | } 318 | }; 319 | 320 | } // namespace dyft 321 | -------------------------------------------------------------------------------- /include/mart_array_full.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "abort_if.hpp" 6 | #include "mart_common.hpp" 7 | 8 | namespace dyft { 9 | 10 | template 11 | class mart_array_full { 12 | public: 13 | static constexpr uint64_t PTR_SIZE = 5; 14 | static constexpr uint64_t PTRS_OFFSET = 1; 15 | 16 | static constexpr uint64_t BYTES = 1 + (256 * PTR_SIZE); // header + 256 ptrs 17 | static constexpr mart_node_types NTYPE = NodeType; 18 | 19 | private: 20 | std::vector m_nodes; 21 | 22 | static mart_pointer get_martptr(const uint8_t* ptrs) { 23 | return {*reinterpret_cast(ptrs), mart_node_types(ptrs[4])}; 24 | } 25 | static void set_martptr(uint8_t* ptrs, mart_pointer new_ptr) { 26 | *reinterpret_cast(ptrs) = new_ptr.nid; 27 | ptrs[4] = uint8_t(new_ptr.ntype); 28 | } 29 | 30 | public: 31 | mart_array_full() = default; 32 | 33 | mart_insert_flags insert_ptr(mart_cursor& mc, uint8_t c, mart_pointer new_ptr) { 34 | DEBUG_ABORT_IF_NE(mc.nptr.ntype, NTYPE); 35 | 36 | const uint64_t pos = mc.nptr.nid * BYTES; 37 | 38 | uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 39 | mart_pointer mptr = get_martptr(ptrs + (c * PTR_SIZE)); 40 | 41 | if (mptr.nid == MART_NILID) { 42 | // not found 43 | m_nodes[pos] += 1; 44 | set_martptr(ptrs + (c * PTR_SIZE), new_ptr); 45 | mc = mart_cursor{uint32_t(c), mc.nptr, new_ptr}; 46 | return MART_INSERTED; 47 | } 48 | 49 | // found 50 | mc = mart_cursor{uint32_t(c), mc.nptr, mptr}; 51 | return MART_FOUND; 52 | } 53 | 54 | // returns new node ptr 55 | mart_pointer make_node() { 56 | if (m_nodes.size() == 0) { 57 |
m_nodes.reserve(BYTES); 58 | } else if (m_nodes.size() + BYTES > m_nodes.capacity()) { 59 | m_nodes.reserve(m_nodes.capacity() * 2); 60 | } 61 | 62 | const uint64_t new_pos = m_nodes.size(); 63 | const uint32_t new_nid = static_cast(new_pos / BYTES); 64 | m_nodes.resize(new_pos + BYTES); 65 | m_nodes[new_pos] = 0; // fanout 66 | 67 | uint8_t* ptrs = &m_nodes[new_pos + PTRS_OFFSET]; 68 | for (uint32_t i = 0; i < 256; i++) { 69 | set_martptr(ptrs + (i * PTR_SIZE), make_mart_nullptr()); 70 | } 71 | 72 | return mart_pointer{new_nid, NTYPE}; 73 | } 74 | 75 | // builds a node holding `edges`; returns its pointer 76 | mart_pointer make_node(const std::vector& edges) { 77 | const mart_pointer new_nptr = make_node(); 78 | 79 | const uint32_t new_nid = new_nptr.nid; 80 | const uint64_t new_pos = new_nid * BYTES; 81 | 82 | m_nodes[new_pos] = static_cast(edges.size()); // NOTE(review): one-byte header, so 256 edges wrap to 0; the count does not appear to be read in this class 83 | uint8_t* ptrs = &m_nodes[new_pos + PTRS_OFFSET]; 84 | 85 | for (const mart_edge& e : edges) { 86 | set_martptr(ptrs + (e.c * PTR_SIZE), e.nptr); 87 | } 88 | return new_nptr; 89 | } 90 | 91 | void update_srcptr(const mart_cursor& mc) { 92 | DEBUG_ABORT_IF_NE(mc.pptr.ntype, NTYPE); 93 | 94 | const uint64_t pos = mc.pptr.nid * BYTES; 95 | uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 96 | 97 | set_martptr(ptrs + (mc.offset * PTR_SIZE), mc.nptr); 98 | } 99 | 100 | mart_pointer find_child(mart_pointer nptr, uint8_t c) { 101 | DEBUG_ABORT_IF_NE(nptr.ntype, NTYPE); 102 | 103 | const uint64_t pos = nptr.nid * BYTES; 104 | const uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 105 | 106 | return get_martptr(ptrs + (c * PTR_SIZE)); 107 | } 108 | 109 | struct searcher { 110 | searcher() = default; 111 | 112 | searcher(const mart_array_full* obj, mart_pointer nptr) { 113 | set(obj, nptr); 114 | } 115 | 116 | void set(const mart_array_full* obj, mart_pointer nptr) { 117 | DEBUG_ABORT_IF_NE(nptr.ntype, NTYPE); 118 | 119 | const uint64_t pos = nptr.nid * BYTES; 120 | m_ptrs = &obj->m_nodes[pos + PTRS_OFFSET]; 121 | } 122 | 123 | mart_pointer find(uint8_t c) { // nid is MART_NILID when c is absent
124 | return get_martptr(m_ptrs + (c * PTR_SIZE)); 125 | } 126 | void scan(std::vector& edges) { 127 | for (uint32_t i = 0; i < 256; i++) { 128 | mart_pointer ptr = get_martptr(m_ptrs + (i * PTR_SIZE)); 129 | if (ptr.nid != MART_NILID) { 130 | edges.push_back(mart_edge{uint8_t(i), ptr}); 131 | } 132 | } 133 | } 134 | 135 | private: 136 | const uint8_t* m_ptrs = nullptr; 137 | }; 138 | 139 | uint32_t num_nodes() { // no free list here, so every allocated node is live 140 | return m_nodes.size() / BYTES; 141 | } 142 | uint32_t num_emp_nodes() { 143 | return 0; 144 | } 145 | }; 146 | 147 | } // namespace dyft 148 | -------------------------------------------------------------------------------- /include/mart_array_sparse.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #if defined(__AVX__) || defined(__AVX2__) 4 | #include 5 | #endif 6 | 7 | #include 8 | #include 9 | 10 | #include "abort_if.hpp" 11 | #include "mart_common.hpp" 12 | 13 | namespace dyft { 14 | 15 | template 16 | class mart_array_sparse { 17 | static_assert(2 <= K and K <= 255); // the child count lives in a one-byte header 18 | 19 | public: 20 | static constexpr uint64_t PTR_SIZE = 5; 21 | static constexpr uint64_t KEYS_OFFSET = 1; 22 | static constexpr uint64_t PTRS_OFFSET = 1 + K; 23 | 24 | static constexpr uint64_t BYTES = 1 + K + (K * PTR_SIZE); // header + K keys + K ptrs 25 | static constexpr mart_node_types NTYPE = NodeType; 26 | 27 | private: 28 | std::vector m_nodes; 29 | uint32_t m_head_nid = MART_NILID; 30 | uint32_t m_num_emps = 0; 31 | 32 | static mart_pointer get_martptr(const uint8_t* ptrs) { 33 | return {*reinterpret_cast(ptrs), mart_node_types(ptrs[4])}; 34 | } 35 | static void set_martptr(uint8_t* ptrs, mart_pointer new_ptr) { 36 | *reinterpret_cast(ptrs) = new_ptr.nid; 37 | ptrs[4] = uint8_t(new_ptr.ntype); 38 | } 39 | 40 | public: 41 | mart_array_sparse() = default; 42 | 43 | mart_insert_flags insert_ptr(mart_cursor& mc, uint8_t c, mart_pointer new_ptr) { 44 | DEBUG_ABORT_IF_NE(mc.nptr.ntype, NTYPE); 45 | 46 | const
uint64_t pos = mc.nptr.nid * BYTES; 47 | 48 | const uint32_t num = m_nodes[pos]; 49 | uint8_t* keys = &m_nodes[pos + KEYS_OFFSET]; 50 | uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 51 | 52 | DEBUG_ABORT_IF_LT(K, num); 53 | 54 | for (uint32_t i = 0; i < num; i++) { 55 | if (keys[i] == c) { // found? 56 | mc = mart_cursor{i, mc.nptr, get_martptr(ptrs + (i * PTR_SIZE))}; 57 | return MART_FOUND; 58 | } 59 | } 60 | 61 | // can insert 62 | if (num < K) { 63 | m_nodes[pos] += 1; 64 | keys[num] = c; 65 | set_martptr(ptrs + (num * PTR_SIZE), new_ptr); 66 | 67 | mc = mart_cursor{num, mc.nptr, new_ptr}; 68 | return MART_INSERTED; // inserted 69 | } 70 | 71 | return MART_NEEDED_TO_EXPAND; 72 | } 73 | 74 | mart_insert_flags append_ptr(mart_cursor& mc, uint8_t c, mart_pointer new_ptr) { // insert without lookup: caller ensures c is absent; only free capacity is debug-checked 75 | DEBUG_ABORT_IF_NE(mc.nptr.ntype, NTYPE); 76 | 77 | const uint64_t pos = mc.nptr.nid * BYTES; 78 | 79 | const uint32_t num = m_nodes[pos]; 80 | uint8_t* keys = &m_nodes[pos + KEYS_OFFSET]; 81 | uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 82 | 83 | DEBUG_ABORT_IF_LE(K, num); 84 | 85 | // not searched 86 | m_nodes[pos] += 1; 87 | keys[num] = c; 88 | set_martptr(ptrs + (num * PTR_SIZE), new_ptr); 89 | 90 | // set src 91 | mc = mart_cursor{num, mc.nptr, new_ptr}; 92 | return MART_INSERTED; 93 | } 94 | 95 | void extract(const mart_cursor& mc, std::vector& edges) { 96 | DEBUG_ABORT_IF_NE(mc.nptr.ntype, NTYPE); 97 | 98 | const uint32_t nid = mc.nptr.nid; 99 | const uint64_t pos = nid * BYTES; 100 | 101 | const uint32_t num = m_nodes[pos]; 102 | const uint8_t* keys = &m_nodes[pos + KEYS_OFFSET]; 103 | const uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 104 | 105 | for (uint32_t i = 0; i < num; i++) { 106 | edges.push_back(mart_edge{keys[i], get_martptr(ptrs + (i * PTR_SIZE))}); 107 | } 108 | 109 | // Push this node onto the circular doubly-linked free list (links overwrite its first 8 bytes) 110 | if (m_head_nid == MART_NILID) { 111 | prev_ref(nid) = nid; 112 | next_ref(nid) = nid; 113 | m_head_nid = nid; 114 | } else { 115 | const uint32_t prev_nid = prev_ref(m_head_nid); 116
| prev_ref(nid) = prev_nid; 117 | next_ref(nid) = m_head_nid; 118 | next_ref(prev_nid) = nid; 119 | prev_ref(m_head_nid) = nid; 120 | } 121 | 122 | m_num_emps += 1; 123 | } 124 | 125 | mart_pointer make_node() { 126 | // Reuse a node from the free list; stale keys are harmless once fanout is reset to 0 127 | if (m_head_nid != MART_NILID) { 128 | m_num_emps -= 1; 129 | 130 | const uint32_t new_nid = m_head_nid; 131 | const uint32_t prev_nid = prev_ref(m_head_nid); 132 | const uint32_t next_nid = next_ref(m_head_nid); 133 | 134 | if (next_nid == m_head_nid) { 135 | m_head_nid = MART_NILID; 136 | } else { 137 | m_head_nid = next_nid; 138 | prev_ref(next_nid) = prev_nid; 139 | next_ref(prev_nid) = next_nid; 140 | } 141 | m_nodes[new_nid * BYTES] = 0; // fanout 142 | return mart_pointer{new_nid, NTYPE}; 143 | } 144 | 145 | if (m_nodes.size() == 0) { 146 | m_nodes.reserve(BYTES); 147 | } else if (m_nodes.size() + BYTES > m_nodes.capacity()) { 148 | m_nodes.reserve(m_nodes.capacity() * 2); 149 | } 150 | 151 | const uint64_t new_pos = m_nodes.size(); 152 | const uint32_t new_nid = static_cast(new_pos / BYTES); 153 | m_nodes.resize(new_pos + BYTES); 154 | m_nodes[new_pos] = 0; // fanout 155 | 156 | return mart_pointer{new_nid, NTYPE}; 157 | } 158 | 159 | // builds a node holding `edges`; returns its pointer 160 | mart_pointer make_node(const std::vector& edges) { 161 | ABORT_IF_OUT(edges.size(), 0, K); 162 | 163 | const mart_pointer new_nptr = make_node(); 164 | 165 | const uint32_t new_nid = new_nptr.nid; 166 | const uint64_t new_pos = new_nid * BYTES; 167 | 168 | uint8_t* keys = &m_nodes[new_pos + KEYS_OFFSET]; 169 | uint8_t* ptrs = &m_nodes[new_pos + PTRS_OFFSET]; 170 | 171 | m_nodes[new_pos] = static_cast(edges.size()); 172 | for (uint32_t i = 0; i < edges.size(); i++) { 173 | keys[i] = edges[i].c; 174 | set_martptr(ptrs + (i * PTR_SIZE), edges[i].nptr); 175 | } 176 | 177 | return new_nptr; 178 | } 179 | 180 | void update_srcptr(const mart_cursor& mc) { 181 | DEBUG_ABORT_IF_NE(mc.pptr.ntype, NTYPE); 182 | 183 | const uint64_t pos = mc.pptr.nid * BYTES; 184 |
uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 185 | 186 | set_martptr(ptrs + (mc.offset * PTR_SIZE), mc.nptr); 187 | } 188 | 189 | mart_pointer find_child(mart_pointer nptr, uint8_t c) { 190 | DEBUG_ABORT_IF_NE(nptr.ntype, NTYPE); 191 | 192 | const uint32_t nid = nptr.nid; 193 | const uint64_t pos = nid * BYTES; 194 | 195 | const uint32_t num = m_nodes[pos]; 196 | const uint8_t* keys = &m_nodes[pos + KEYS_OFFSET]; 197 | const uint8_t* ptrs = &m_nodes[pos + PTRS_OFFSET]; 198 | 199 | return adaptive_find(num, keys, ptrs, c); 200 | } 201 | 202 | struct searcher { 203 | searcher() = default; 204 | 205 | searcher(const mart_array_sparse* obj, mart_pointer nptr) { 206 | set(obj, nptr); 207 | } 208 | 209 | void set(const mart_array_sparse* obj, mart_pointer nptr) { 210 | DEBUG_ABORT_IF_NE(nptr.ntype, NTYPE); 211 | 212 | const uint64_t pos = nptr.nid * BYTES; 213 | 214 | m_num = obj->m_nodes[pos]; 215 | m_keys = &obj->m_nodes[pos + KEYS_OFFSET]; 216 | m_ptrs = &obj->m_nodes[pos + PTRS_OFFSET]; 217 | } 218 | 219 | mart_pointer find(uint8_t c) { 220 | return adaptive_find(m_num, m_keys, m_ptrs, c); 221 | } 222 | void scan(std::vector& edges) { 223 | for (uint32_t i = 0; i < m_num; i++) { 224 | edges.push_back(mart_edge{m_keys[i], get_martptr(m_ptrs + (i * PTR_SIZE))}); 225 | } 226 | } 227 | 228 | private: 229 | uint32_t m_num = 0; 230 | const uint8_t* m_keys = nullptr; 231 | const uint8_t* m_ptrs = nullptr; 232 | }; 233 | 234 | void population_stats() { // prints a fanout histogram over nodes not on the free list 235 | std::set emps = get_emps(); 236 | 237 | uint64_t cnts[K + 1]; 238 | std::fill(cnts, cnts + K + 1, 0); 239 | 240 | for (uint64_t i = 0; i < m_nodes.size(); i += BYTES) { 241 | if (emps.find(i) != emps.end()) { 242 | continue; 243 | } 244 | DEBUG_ABORT_IF_LT(K, m_nodes[i]); 245 | cnts[m_nodes[i]] += 1; 246 | } 247 | 248 | const double sum = std::accumulate(cnts, cnts + K + 1, 0.0); 249 | 250 | tfm::printfln("## mart_array_sparse (K=%d) ##", K); 251 | for (uint32_t k = 0; k <= K; k++) { 252 | if (cnts[k] == 0) continue; 253
| tfm::printfln("- k=%d: %d (%g)", k, cnts[k], cnts[k] / sum); 254 | } 255 | } 256 | 257 | uint32_t num_nodes() { 258 | return (m_nodes.size() / BYTES) - m_num_emps; 259 | } 260 | uint32_t num_emp_nodes() { 261 | return m_num_emps; 262 | } 263 | 264 | void debug_emplist() { 265 | uint32_t emps = 0; 266 | if (m_head_nid != MART_NILID) { 267 | uint32_t emp_nid = m_head_nid; 268 | while (true) { 269 | emps += 1; 270 | uint32_t next_nid = next_ref(emp_nid); 271 | if (next_nid == m_head_nid) break; 272 | emp_nid = next_nid; 273 | } 274 | } 275 | ABORT_IF_NE(m_num_emps, emps); 276 | } 277 | 278 | std::set get_emps() { 279 | std::set emps; 280 | if (m_head_nid != MART_NILID) { 281 | uint32_t emp_nid = m_head_nid; 282 | while (true) { 283 | emps.insert(emp_nid * BYTES); 284 | uint32_t next_nid = next_ref(emp_nid); 285 | if (next_nid == m_head_nid) break; 286 | emp_nid = next_nid; 287 | } 288 | } 289 | return emps; 290 | } 291 | 292 | private: 293 | uint32_t& prev_ref(uint32_t nid) { // NOTE(review): accesses a uint32_t at byte offset nid * BYTES, which may be misaligned; memcpy-based access would avoid UB 294 | return reinterpret_cast(&m_nodes[nid * BYTES])[0]; 295 | } 296 | uint32_t& next_ref(uint32_t nid) { 297 | return reinterpret_cast(&m_nodes[nid * BYTES])[1]; 298 | } 299 | 300 | static mart_pointer adaptive_find(uint32_t num, const uint8_t* keys, const uint8_t* ptrs, uint8_t c) { // child for key c among the first num keys, or a nil pointer 301 | if constexpr (K <= 8) { 302 | // Linear Scan 303 | for (uint32_t i = 0; i < num; i++) { 304 | if (keys[i] == c) { // found?
305 | return get_martptr(ptrs + (i * PTR_SIZE)); 306 | } 307 | } 308 | } else { 309 | #ifdef __AVX2__ 310 | __m256i c_256i = _mm256_set1_epi8(c); 311 | for (uint32_t i = 0; i < num; i += 32) { // match bits for lanes at index >= num are masked off below 312 | __m256i cmp = _mm256_cmpeq_epi8(c_256i, _mm256_loadu_si256((__m256i*)(keys + i))); 313 | const unsigned check_bits = _mm256_movemask_epi8(cmp) & ((1ULL << std::min(32U, num - i)) - 1ULL); 314 | if (check_bits) { 315 | const unsigned j = __builtin_ctz(check_bits) + i; 316 | return get_martptr(ptrs + (j * PTR_SIZE)); 317 | } 318 | } 319 | #elif __AVX__ 320 | __m128i c_128i = _mm_set1_epi8(c); 321 | for (uint32_t i = 0; i < num; i += 16) { 322 | __m128i cmp = _mm_cmpeq_epi8(c_128i, _mm_loadu_si128((__m128i*)(keys + i))); 323 | const unsigned check_bits = _mm_movemask_epi8(cmp) & ((1U << std::min(16U, num - i)) - 1U); 324 | if (check_bits) { 325 | const unsigned j = __builtin_ctz(check_bits) + i; 326 | return get_martptr(ptrs + (j * PTR_SIZE)); 327 | } 328 | } 329 | #else 330 | // Linear Scan 331 | for (uint32_t i = 0; i < num; i++) { 332 | if (keys[i] == c) { // found?
333 | return get_martptr(ptrs + (i * PTR_SIZE)); 334 | } 335 | } 336 | #endif 337 | } 338 | return make_mart_nullptr(); 339 | } 340 | }; 341 | 342 | } // namespace dyft 343 | -------------------------------------------------------------------------------- /include/mart_common.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "abort_if.hpp" 4 | 5 | namespace dyft { 6 | 7 | /* * * * * * * * * * * * 8 | * Basic definitions of MART 9 | */ 10 | enum mart_insert_flags { 11 | MART_FOUND, 12 | MART_INSERTED, 13 | MART_NEEDED_TO_EXPAND, 14 | }; 15 | 16 | #define MART_SPACE_EFFICIENT 17 | 18 | #ifdef MART_SPACE_EFFICIENT 19 | enum mart_node_types : uint8_t { 20 | MART_LEAF, 21 | MART_NODE_2, 22 | MART_NODE_4, 23 | MART_NODE_8, 24 | MART_NODE_16, 25 | MART_NODE_32, 26 | MART_NODE_64, 27 | MART_NODE_128, 28 | MART_NODE_256, 29 | MART_NIL_NTYPE, 30 | }; 31 | #else 32 | enum mart_node_types : uint8_t { 33 | MART_LEAF, 34 | MART_NODE_4, 35 | MART_NODE_32, 36 | MART_NODE_256, 37 | MART_NIL_NTYPE, 38 | }; 39 | #endif 40 | 41 | static constexpr uint32_t MART_NILID = UINT32_MAX; // node id used for absent children and an empty free list 42 | 43 | struct mart_pointer { 44 | uint32_t nid; 45 | mart_node_types ntype; 46 | }; 47 | 48 | inline mart_pointer make_mart_nullptr() { 49 | return mart_pointer{MART_NILID, MART_NIL_NTYPE}; 50 | } 51 | inline bool check_mart_nullptr(mart_pointer p) { 52 | return p.nid == MART_NILID; 53 | } 54 | 55 | struct mart_cursor { 56 | uint32_t offset; 57 | mart_pointer pptr; // src pointer 58 | mart_pointer nptr; 59 | }; 60 | 61 | struct mart_edge { 62 | uint8_t c; 63 | mart_pointer nptr; 64 | }; 65 | 66 | } // namespace dyft 67 | -------------------------------------------------------------------------------- /include/mi_frame.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "array_index.hpp" 4 | #include "art_index.hpp" 5 | #include "mart_index.hpp" 6 | 7 | namespace dyft { 8
| 9 | template 10 | class mi_frame : public dyft_interface { 11 | public: 12 | using index_type = Index; 13 | using vint_type = typename index_type::vint_type; 14 | 15 | static constexpr int LEN = index_type::LEN; 16 | 17 | private: 18 | const vcode_array* m_database = nullptr; 19 | const int m_blocks = 0; 20 | 21 | std::vector m_radii; 22 | std::vector> m_indexes; 23 | uint32_t m_ids = 0; 24 | 25 | // for query 26 | std::vector m_cands; 27 | 28 | public: 29 | mi_frame(const vcode_array* database, int radius, int blocks, int splitthr, double in_weight) 30 | : m_database(database), m_blocks(blocks), m_radii(blocks), m_indexes(blocks) { 31 | ABORT_IF_LE(blocks, 1); 32 | 33 | int beg = 0; // first chunk position of the current block 34 | const int gph = radius - m_blocks + 1; // the per-block radii below sum to this (generalized pigeonhole), assuming distances add across blocks 35 | const int chunks_size = index_type::get_chunks_size(m_database->get_bits()); 36 | 37 | for (int b = 0; b < m_blocks; b++) { 38 | const int m = (b + chunks_size) / m_blocks; 39 | const int shifted = b + gph; m_radii[b] = (shifted >= 0) ? shifted / m_blocks : -((m_blocks - 1 - shifted) / m_blocks); // floor(shifted / m_blocks): plain '/' truncates toward zero and would never yield the negative (skipped) radii handled below 40 | if (m_radii[b] >= 0) { 41 | m_indexes[b] = std::make_unique(database, m_radii[b], beg, beg + m, splitthr, in_weight); 42 | } 43 | beg += m; 44 | } 45 | ABORT_IF_NE(beg, chunks_size); 46 | 47 | m_cands.reserve(1U << 10); 48 | } 49 | 50 | uint32_t append() override { 51 | for (int b = 0; b < m_blocks; b++) { 52 | if (m_indexes[b]) { 53 | ABORT_IF_NE(m_indexes[b]->append(), m_ids); 54 | } 55 | } 56 | return m_ids++; 57 | } 58 | 59 | void range_search(const vint_type* vcode, const std::function& fn) override { 60 | m_cands.clear(); 61 | 62 | if (selects_ls()) { 63 | for (uint32_t id = 0; id < get_size(); id++) { 64 | fn(id); 65 | } 66 | return; 67 | } 68 | 69 | for (int b = 0; b < m_blocks; b++) { 70 | if (m_radii[b] >= 0) { 71 | DEBUG_ABORT_IF(!m_indexes[b]); 72 | m_indexes[b]->trie_search(vcode, [&](uint32_t bi) { m_cands.push_back(bi); }); 73 | } 74 | } 75 | 76 | std::sort(m_cands.begin(), m_cands.end()); 77 | m_cands.erase(std::unique(m_cands.begin(), m_cands.end()), m_cands.end()); 78 | 79 | for
(uint32_t id : m_cands) { 80 | fn(id); 81 | } 82 | } 83 | 84 | uint32_t get_size() const override { 85 | return m_ids; 86 | } 87 | int get_bits() const override { 88 | return m_database->get_bits(); 89 | } 90 | 91 | uint32_t get_leaves() const override { 92 | return 0; 93 | } 94 | size_t get_split_count() override { 95 | return 0; 96 | } 97 | 98 | bool selects_ls() const override { 99 | for (int b = 0; b < m_blocks; b++) { 100 | if (m_indexes[b] and m_indexes[b]->selects_ls()) { 101 | return true; 102 | } 103 | } 104 | return false; 105 | } 106 | 107 | double get_ls_cost() const override { 108 | return 0.0; 109 | } 110 | double get_trie_cost() const override { 111 | return 0.0; 112 | } 113 | 114 | void innode_stats() override {} 115 | void population_stats() override {} 116 | }; 117 | 118 | } // namespace dyft 119 | -------------------------------------------------------------------------------- /include/misc.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __APPLE__ 4 | #include 5 | #endif 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #include "bit_tools.hpp" 22 | 23 | namespace dyft { 24 | 25 | template 26 | inline float get_average(const std::vector& vec) { 27 | return std::accumulate(vec.begin(), vec.end(), 0.0) / vec.size(); 28 | } 29 | 30 | inline std::string normalize_filepath(std::string filepath) { 31 | std::replace(filepath.begin(), filepath.end(), '/', '_'); 32 | std::replace(filepath.begin(), filepath.end(), '.', '_'); 33 | std::replace(filepath.begin(), filepath.end(), ':', '_'); 34 | return filepath; 35 | } 36 | 37 | inline int get_hamdist(uint64_t x, uint64_t y) { 38 | return static_cast(bit_tools::popcnt(x ^ y)); 39 | } 40 | 41 | // ceil(a / b), cf.
https://nariagari-igakusei.com/cpp-division-round-up/ 42 | inline constexpr int ceil_div(int a, int b) { 43 | return (a + (b - 1)) / b; 44 | } 45 | 46 | // From Cedar (http://www.tkl.iis.u-tokyo.ac.jp/~ynaga/cedar/) 47 | inline size_t get_process_size_in_bytes() { 48 | #ifdef __APPLE__ 49 | struct task_basic_info t_info; 50 | mach_msg_type_number_t t_info_count = TASK_BASIC_INFO_COUNT; 51 | task_info(current_task(), TASK_BASIC_INFO, reinterpret_cast(&t_info), &t_info_count); 52 | return t_info.resident_size; 53 | #else 54 | FILE* fp = std::fopen("/proc/self/statm", "r"); 55 | size_t dummy(0), vm(0); 56 | std::fscanf(fp, "%ld %ld ", &dummy, &vm); // get resident (see procfs) 57 | std::fclose(fp); 58 | return vm * ::getpagesize(); 59 | #endif 60 | } 61 | 62 | inline double to_MiB(size_t bytes) { 63 | return bytes / (1024.0 * 1024.0); 64 | } 65 | 66 | inline double to_GiB(size_t bytes) { 67 | return bytes / (1024.0 * 1024.0 * 1024.0); 68 | } 69 | 70 | } // namespace dyft 71 | -------------------------------------------------------------------------------- /include/sparse_group.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "abort_if.hpp" 4 | #include "bit_tools.hpp" 5 | 6 | namespace dyft { 7 | 8 | class sparse_group { 9 | public: 10 | static constexpr uint32_t SIZE = 64; 11 | 12 | private: 13 | uint64_t m_bitmap = 0; 14 | std::vector m_group; 15 | 16 | public: 17 | sparse_group() = default; 18 | 19 | std::pair access(uint32_t idx) { 20 | ABORT_IF_LE(SIZE, idx); 21 | 22 | if ((m_bitmap & (1ULL << idx)) == 0ULL) { 23 | return {nullptr, nullptr}; 24 | } 25 | 26 | const uint64_t bitmask = (1ULL << idx) - 1ULL; 27 | const int howmany = bit_tools::popcnt(m_bitmap & bitmask); 28 | const int totones = bit_tools::popcnt(m_bitmap); 29 | 30 | const uint32_t size = m_group[howmany + 1] - m_group[howmany]; 31 | uint32_t* ptr = m_group.data() + (totones + 1 + m_group[howmany]); 32 | 33 | return {ptr, ptr + 
size}; 34 | } 35 | 36 | void insert(uint32_t idx, uint32_t data) { 37 | ABORT_IF_LE(SIZE, idx); 38 | 39 | if (m_bitmap == 0) { 40 | m_bitmap = 1ULL << idx; 41 | m_group = std::vector{0, 1, data}; 42 | return; 43 | } 44 | 45 | const uint64_t bitmask = (1ULL << idx) - 1ULL; 46 | const int howmany = bit_tools::popcnt(m_bitmap & bitmask); 47 | 48 | if ((m_bitmap & (1ULL << idx)) == 0ULL) { 49 | m_group.insert(m_group.begin() + howmany, m_group[howmany]); 50 | m_bitmap |= (1ULL << idx); 51 | } 52 | 53 | const int totones = bit_tools::popcnt(m_bitmap); 54 | 55 | // totones + 1 = dataの始点 56 | // m_group[howmany + 1] = idxの末尾 57 | m_group.insert(m_group.begin() + (totones + 1 + m_group[howmany + 1]), data); 58 | for (int i = howmany + 1; i <= totones; i++) { 59 | m_group[i] += 1; 60 | } 61 | } 62 | 63 | void insert(uint32_t idx, const std::vector& datvec) { 64 | ABORT_IF_LE(SIZE, idx); 65 | 66 | if (m_bitmap == 0) { 67 | m_bitmap = 1ULL << idx; 68 | m_group = std::vector{0, static_cast(datvec.size())}; 69 | std::copy(datvec.begin(), datvec.end(), std::back_inserter(m_group)); 70 | return; 71 | } 72 | 73 | const uint64_t bitmask = (1ULL << idx) - 1ULL; 74 | const int howmany = bit_tools::popcnt(m_bitmap & bitmask); 75 | 76 | if ((m_bitmap & (1ULL << idx)) == 0ULL) { 77 | m_group.insert(m_group.begin() + howmany, m_group[howmany]); 78 | m_bitmap |= (1ULL << idx); 79 | } 80 | 81 | const int totones = bit_tools::popcnt(m_bitmap); 82 | 83 | // totones + 1 = dataの始点 84 | // m_group[howmany + 1] = idxの末尾 85 | m_group.insert(m_group.begin() + (totones + 1 + m_group[howmany + 1]), datvec.begin(), datvec.end()); 86 | for (int i = howmany + 1; i <= totones; i++) { 87 | m_group[i] += datvec.size(); 88 | } 89 | } 90 | 91 | std::vector extract(uint32_t idx) { 92 | ABORT_IF_LE(SIZE, idx); 93 | 94 | if ((m_bitmap & (1ULL << idx)) == 0ULL) { 95 | return {}; 96 | } 97 | 98 | const uint64_t bitmask = (1ULL << idx) - 1ULL; 99 | const int howmany = bit_tools::popcnt(m_bitmap & bitmask); 100 | 
const int totones = bit_tools::popcnt(m_bitmap); 101 | 102 | const uint32_t size = m_group[howmany + 1] - m_group[howmany]; 103 | const uint32_t pos = totones + 1 + m_group[howmany]; 104 | 105 | std::vector vec(size); 106 | std::copy(&m_group[pos], &m_group[pos + size], vec.data()); 107 | 108 | for (int i = howmany + 2; i <= totones; i++) { 109 | m_group[i] = m_group[i] - size; 110 | } 111 | m_group.erase(m_group.begin() + pos, m_group.begin() + pos + size); 112 | m_group.erase(m_group.begin() + howmany + 1); 113 | 114 | m_bitmap = m_bitmap & ~(1ULL << idx); 115 | 116 | return vec; 117 | } 118 | 119 | uint32_t size(uint32_t idx) const { 120 | ABORT_IF_LE(SIZE, idx); 121 | 122 | if ((m_bitmap & (1ULL << idx)) == 0ULL) { 123 | return 0; 124 | } 125 | 126 | const uint64_t bitmask = (1ULL << idx) - 1ULL; 127 | const int howmany = bit_tools::popcnt(m_bitmap & bitmask); 128 | 129 | return m_group[howmany + 1] - m_group[howmany]; 130 | } 131 | }; 132 | 133 | } // namespace dyft 134 | -------------------------------------------------------------------------------- /include/sparse_table.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "sparse_group.hpp" 4 | 5 | namespace dyft { 6 | 7 | class sparse_table { 8 | private: 9 | std::vector m_groups; 10 | uint32_t m_size = 0; 11 | 12 | public: 13 | sparse_table() = default; 14 | 15 | std::pair access(uint32_t idx) { 16 | ABORT_IF_LE(m_size, idx); 17 | const uint32_t gpos = idx / sparse_group::SIZE; 18 | const uint32_t gmod = idx % sparse_group::SIZE; 19 | return m_groups[gpos].access(gmod); 20 | } 21 | 22 | void push_back() { 23 | if (m_size / sparse_group::SIZE == m_groups.size()) { 24 | m_groups.push_back(sparse_group()); 25 | } 26 | m_size += 1; 27 | } 28 | 29 | void push_back(const std::vector& datvec) { 30 | if (m_size / sparse_group::SIZE == m_groups.size()) { 31 | m_groups.push_back(sparse_group()); 32 | } 33 | const uint32_t gpos = m_size / 
sparse_group::SIZE; 34 | const uint32_t gmod = m_size % sparse_group::SIZE; 35 | m_groups[gpos].insert(gmod, datvec); 36 | m_size += 1; 37 | } 38 | 39 | void insert(uint32_t idx, uint32_t data) { 40 | ABORT_IF_LE(m_size, idx); 41 | const uint32_t gpos = idx / sparse_group::SIZE; 42 | const uint32_t gmod = idx % sparse_group::SIZE; 43 | m_groups[gpos].insert(gmod, data); 44 | } 45 | 46 | void insert(uint32_t idx, const std::vector& datvec) { 47 | ABORT_IF_LE(m_size, idx); 48 | const uint32_t gpos = idx / sparse_group::SIZE; 49 | const uint32_t gmod = idx % sparse_group::SIZE; 50 | m_groups[gpos].insert(gmod, datvec); 51 | } 52 | 53 | std::vector extract(uint32_t idx) { 54 | ABORT_IF_LE(m_size, idx); 55 | const uint32_t gpos = idx / sparse_group::SIZE; 56 | const uint32_t gmod = idx % sparse_group::SIZE; 57 | return m_groups[gpos].extract(gmod); 58 | } 59 | 60 | uint32_t size() const { 61 | return m_size; 62 | } 63 | 64 | uint32_t group_size(uint32_t idx) const { 65 | ABORT_IF_LE(m_size, idx); 66 | const uint32_t gpos = idx / sparse_group::SIZE; 67 | const uint32_t gmod = idx % sparse_group::SIZE; 68 | return m_groups[gpos].size(gmod); 69 | } 70 | }; 71 | 72 | } // namespace dyft 73 | -------------------------------------------------------------------------------- /include/splitmix.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace dyft { 6 | 7 | // From http://xoroshiro.di.unimi.it/splitmix64.c 8 | class splitmix64 { 9 | public: 10 | splitmix64(uint64_t seed) : x(seed){}; 11 | 12 | uint64_t next() { 13 | uint64_t z = (x += uint64_t(0x9E3779B97F4A7C15)); 14 | z = (z ^ (z >> 30)) * uint64_t(0xBF58476D1CE4E5B9); 15 | z = (z ^ (z >> 27)) * uint64_t(0x94D049BB133111EB); 16 | return z ^ (z >> 31); 17 | } 18 | 19 | private: 20 | uint64_t x; 21 | }; 22 | 23 | } // namespace dyft 24 | -------------------------------------------------------------------------------- 
/include/statistic_reporter.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "tinyformat/tinyformat.h" 19 | 20 | namespace dyft { 21 | 22 | class statistic_reporter { // Collects tags and per-key tables of rows, and writes them as JSON through boost::property_tree 23 | public: 24 | using value_type = boost::variant, // 7 32 | std::vector, // 8 33 | std::vector, // 9 34 | std::vector, // 10 35 | std::vector, // 11 36 | std::vector // 12 37 | >; 38 | using row_type = std::map; 39 | using table_type = std::vector; 40 | 41 | template 42 | using NthTypeOf = typename std::tuple_element>::type; 43 | 44 | template 45 | static auto& access(boost::variant& v) { 46 | using target = NthTypeOf; 47 | return boost::get(v); 48 | } 49 | 50 | template 51 | static auto& access(const boost::variant& v) { 52 | using target = NthTypeOf; 53 | return boost::get(v); 54 | } 55 | 56 | private: 57 | boost::posix_time::ptime m_date; 58 | std::map m_tags; 59 | std::map m_tables; 60 | 61 | public: 62 | statistic_reporter() : m_date(boost::posix_time::second_clock::local_time()) {} 63 | 64 | static statistic_reporter& get_instance() { // process-wide instance (function-local static) 65 | static statistic_reporter instance; 66 | return instance; 67 | } 68 | 69 | void tag(const std::string& key, const value_type& val) { // sets or overwrites tag key 70 | m_tags[key] = val; 71 | } 72 | void append(const std::string& key, const row_type& row) { // adds row to table key, creating the table on first use 73 | auto itr = m_tables.find(key); 74 | if (itr == m_tables.end()) { 75 | m_tables[key] = table_type{row}; 76 | } else { 77 | itr->second.push_back(row); 78 | } 79 | } 80 | 81 | template 82 | static boost::property_tree::ptree vec_to_ptree(const std::vector& vec) { 83 | boost::property_tree::ptree pt; 84 | for (const auto& v : vec) { // const ref: no per-element copy 85 | boost::property_tree::ptree ct; 86 | ct.put("", v); 87 | pt.push_back(std::make_pair("", ct)); 88 | } 89 | return pt; 90 | } 91 | 92 | template 93 |
static void update_ptree(boost::property_tree::ptree& pt, const std::string& key, const value_type& val) { // recurses to the next alternative (presumably N + 1; that template argument is not visible here) until val.which() == N; alternatives 0-6 are put as scalars, 7-12 as JSON arrays 94 | if constexpr (N <= 6) { 95 | if (val.which() == N) { 96 | pt.put(key, access(val)); 97 | } else { 98 | update_ptree(pt, key, val); 99 | } 100 | } else if constexpr (N <= 12) { 101 | if (val.which() == N) { 102 | pt.add_child(key, vec_to_ptree(access(val))); 103 | } else { 104 | update_ptree(pt, key, val); 105 | } 106 | } 107 | } 108 | 109 | boost::property_tree::ptree make_ptree() const { // root holds "date", one "tags.KEY" entry per tag, and one array of rows per table 110 | boost::property_tree::ptree root; 111 | root.put("date", boost::posix_time::to_iso_extended_string(m_date)); 112 | for (const auto& v : m_tags) { 113 | update_ptree<0>(root, tfm::format("tags.%s", v.first), v.second); 114 | } 115 | for (const auto& t : m_tables) { 116 | boost::property_tree::ptree c; 117 | for (const auto& r : t.second) { 118 | boost::property_tree::ptree cc; 119 | for (const auto& v : r) { 120 | update_ptree<0>(cc, tfm::format("%s", v.first), v.second); 121 | } 122 | c.push_back(std::make_pair("", cc)); 123 | } 124 | root.add_child(t.first, c); 125 | } 126 | return root; 127 | } 128 | 129 | void save_json(const std::string& path) const { 130 | boost::property_tree::write_json(path, make_ptree()); 131 | } 132 | }; 133 | 134 | #define STATISTIC_TAG(key, val) statistic_reporter::get_instance().tag(std::string(key), val) 135 | #define STATISTIC_APPEND(key, ...)
\ 136 | statistic_reporter::get_instance().append(std::string(key), statistic_reporter::row_type(__VA_ARGS__)) 137 | #define STATISTIC_SAVE(path) statistic_reporter::get_instance().save_json(path) 138 | 139 | } // namespace dyft 140 | -------------------------------------------------------------------------------- /include/timer.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "tinyformat/tinyformat.h" 6 | 7 | #define START_TIMER(___name) boost::timer::cpu_timer ___name 8 | 9 | #define STOP_TIMER(___name) ___name.stop() 10 | // Stops ___name and reports its elapsed time. Use as a statement, STOP_TIMER_V(t); -- the macro has no trailing ';', so it is also safe in an unbraced if/else 11 | #define STOP_TIMER_V(___name) \ 12 | do { \ 13 | ___name.stop(); \ 14 | tfm::reportf("[%s] %s", #___name, ___name.format(4).c_str()); \ 15 | } while (false) 16 | // Elapsed wall-clock time in seconds / milliseconds; parenthesized so the macros compose safely inside larger expressions 17 | #define GET_TIMER_SEC(___name) (___name.elapsed().wall / 1000000000.0) 18 | 19 | #define GET_TIMER_MILLISEC(___name) (___name.elapsed().wall / 1000000.0) 20 | -------------------------------------------------------------------------------- /include/vcode_array.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "abort_if.hpp" 7 | #include "vcode_tools.hpp" 8 | 9 | namespace dyft { 10 | 11 | template 12 | class vcode_array { // Each code is stored as m_bits consecutive vint_type words, as produced by vcode_tools::to_vints 13 | public: 14 | using vint_type = typename vcode_tools::vint_type; 15 | 16 | private: 17 | uint32_t m_size = 0; 18 | std::vector m_vcodes; 19 | 20 | const int m_bits = 0; 21 | 22 | public: 23 | explicit vcode_array(int bits) : m_bits(bits) { 24 | ABORT_IF_OUT(m_bits, 1, 8); 25 | } 26 | explicit vcode_array(std::vector&& vcodes, int bits) // takes ownership of pre-built codes; vcodes.size() is assumed to be a multiple of bits 27 | : m_size(bits > 0 ? vcodes.size() / bits : 0), m_vcodes(std::move(vcodes)), m_bits(bits) { ABORT_IF_OUT(m_bits, 1, 8); } // same range check as the other constructor (append() relies on m_bits <= 8); bits > 0 guards the division that runs before the check 28 | 29 | uint32_t append(const uint8_t* code) { 30 | vint_type vcode[8]; // room for the largest m_bits accepted by ABORT_IF_OUT(m_bits, 1, 8) 31 | vcode_tools::to_vints(code, vcode, m_bits); 32 | std::copy(vcode, vcode + m_bits, std::back_inserter(m_vcodes)); 33 | return m_size++; 34 | } 35 | 36 | const vint_type*
access(uint32_t id) const { // pointer to the m_bits words of code id 37 | ABORT_IF_LE(m_size, id); 38 | return &m_vcodes[static_cast<size_t>(id) * m_bits]; // widen first: id * m_bits in 32-bit arithmetic can wrap for very large arrays 39 | } 40 | 41 | const vint_type* data() const { 42 | return m_vcodes.data(); 43 | } 44 | 45 | uint32_t get_size() const { 46 | return m_size; 47 | } 48 | int get_bits() const { 49 | return m_bits; 50 | } 51 | 52 | const std::vector& get_vcodes() const { 53 | return m_vcodes; 54 | } 55 | }; 56 | 57 | } // namespace dyft 58 | -------------------------------------------------------------------------------- /include/vcode_tools.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "bit_tools.hpp" 4 | 5 | namespace dyft { 6 | 7 | template 8 | struct vcode_traits; 9 | 10 | template <> 11 | struct vcode_traits<8> { 12 | using vint_type = uint8_t; 13 | }; 14 | template <> 15 | struct vcode_traits<16> { 16 | using vint_type = uint16_t; 17 | }; 18 | template <> 19 | struct vcode_traits<32> { 20 | using vint_type = uint32_t; 21 | }; 22 | template <> 23 | struct vcode_traits<64> { 24 | using vint_type = uint64_t; 25 | }; 26 | 27 | template 28 | struct vcode_tools { 29 | using vint_type = typename vcode_traits::vint_type; 30 | 31 | static int get_hamdist(const vint_type* x, const vint_type* y, int bits) { // number of positions whose integers differ: popcount of the OR of the per-plane XORs 32 | if (bits == 1) { 33 | return bit_tools::popcnt(x[0] ^ y[0]); 34 | } else { 35 | vint_type diff = 0; 36 | for (int j = 0; j < bits; ++j) { 37 | diff |= (x[j] ^ y[j]); 38 | } 39 | return bit_tools::popcnt(diff); 40 | } 41 | } 42 | 43 | static int get_hamdist(const vint_type* x, const vint_type* y, int bits, int radius) { // as above, but returns as soon as the running count exceeds radius (that early result is only a lower bound) 44 | if (bits == 1) { 45 | return bit_tools::popcnt(x[0] ^ y[0]); 46 | } else { 47 | int dist = 0; 48 | vint_type diff = 0; 49 | for (int j = 0; j < bits; ++j) { 50 | diff |= (x[j] ^ y[j]); 51 | dist = bit_tools::popcnt(diff); 52 | if (dist > radius) { 53 | return dist; 54 | } 55 | } 56 | return dist; 57 | } 58 | } 59 | 60 | static vint_type to_vint(const uint8_t* in, int j) { // bit i of the result = bit j of in[i], for i in [0, N) 61 | vint_type v =
vint_type(0); 62 | for (int i = 0; i < N; ++i) { 63 | vint_type b = (in[i] >> j) & vint_type(1); 64 | v |= (b << i); 65 | } 66 | return v; 67 | } 68 | 69 | // static const vint_type* to_vints(const uint8_t* in, int bits) { 70 | // static vint_type out[64]; 71 | // for (int j = 0; j < bits; ++j) { 72 | // out[j] = to_vint(in, j); 73 | // } 74 | // return out; 75 | // } 76 | 77 | static void to_vints(const uint8_t* in, vint_type* out, int bits) { // out[j] = to_vint(in, j) for j in [0, bits); out must hold bits words 78 | for (int j = 0; j < bits; ++j) { 79 | out[j] = to_vint(in, j); 80 | } 81 | } 82 | 83 | static uint8_t get_int(const vint_type* x, int i, int bits) { // reassembles the bits-bit integer at position i from the bit planes written by to_vints 84 | DEBUG_ABORT_IF_LE(N, i); 85 | 86 | if (bits == 1) { 87 | return (x[0] >> i) & uint8_t(1); 88 | } else { 89 | uint8_t c = uint8_t(0); 90 | for (int j = 0; j < bits; ++j) { 91 | c |= (((x[j] >> i) & uint8_t(1)) << j); 92 | } 93 | return c; 94 | } 95 | } 96 | }; 97 | 98 | } // namespace dyft 99 | -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 2 | 3 | add_executable(sample_bin sample_bin.cpp) 4 | target_link_libraries(sample_bin ${Boost_LIBRARIES}) 5 | 6 | add_executable(sample_int sample_int.cpp) 7 | target_link_libraries(sample_int ${Boost_LIBRARIES}) 8 | 9 | add_executable(simhash simhash.cpp) 10 | target_link_libraries(simhash ${Boost_LIBRARIES}) 11 | 12 | add_executable(precompute precompute.cpp) 13 | target_link_libraries(precompute ${Boost_LIBRARIES}) 14 | 15 | add_executable(bvecs_to_bin bvecs_to_bin.cpp) 16 | target_link_libraries(bvecs_to_bin ${Boost_LIBRARIES}) 17 | 18 | add_executable(gen_uniform gen_uniform.cpp) 19 | target_link_libraries(gen_uniform ${Boost_LIBRARIES}) 20 | 21 | 22 | # # # # # # # # # # # # # # # # 23 | # build_index_bin 24 | # # # # # # # # # # # # # # # # 25 | add_executable(build_index_bin_dyft build_index_bin.cpp) 26 |
target_link_libraries(build_index_bin_dyft ${Boost_LIBRARIES}) 27 | target_compile_definitions(build_index_bin_dyft PRIVATE ALGO_DYFT) 28 | 29 | add_executable(build_index_bin_hms1v build_index_bin.cpp) 30 | target_link_libraries(build_index_bin_hms1v ${Boost_LIBRARIES}) 31 | target_compile_definitions(build_index_bin_hms1v PRIVATE ALGO_HMS1V) 32 | # each variant below compiles the same source with a different ALGO_* backend macro 33 | 34 | # # # # # # # # # # # # # # # # 35 | # build_index_int 36 | # # # # # # # # # # # # # # # # 37 | add_executable(build_index_int_dyft build_index_int.cpp) 38 | target_link_libraries(build_index_int_dyft ${Boost_LIBRARIES}) 39 | target_compile_definitions(build_index_int_dyft PRIVATE ALGO_DYFT) 40 | 41 | add_executable(build_index_int_hms1dv build_index_int.cpp) 42 | target_link_libraries(build_index_int_hms1dv ${Boost_LIBRARIES}) 43 | target_compile_definitions(build_index_int_hms1dv PRIVATE ALGO_HMS1DV) 44 | 45 | add_executable(build_index_int_gv build_index_int.cpp) 46 | target_link_libraries(build_index_int_gv ${Boost_LIBRARIES}) 47 | target_compile_definitions(build_index_int_gv PRIVATE ALGO_GV) 48 | # each variant below compiles the same source with a different ALGO_* backend macro 49 | 50 | # # # # # # # # # # # # # # # # 51 | # range_search_bin 52 | # # # # # # # # # # # # # # # # 53 | add_executable(range_search_bin_ls range_search_bin.cpp) 54 | target_link_libraries(range_search_bin_ls ${Boost_LIBRARIES}) 55 | target_compile_definitions(range_search_bin_ls PRIVATE ALGO_LS) 56 | 57 | add_executable(range_search_bin_dyft range_search_bin.cpp) 58 | target_link_libraries(range_search_bin_dyft ${Boost_LIBRARIES}) 59 | target_compile_definitions(range_search_bin_dyft PRIVATE ALGO_DYFT) 60 | 61 | add_executable(range_search_bin_hms1v range_search_bin.cpp) 62 | target_link_libraries(range_search_bin_hms1v ${Boost_LIBRARIES}) 63 | target_compile_definitions(range_search_bin_hms1v PRIVATE ALGO_HMS1V) 64 | 65 | 66 | # # # # # # # # # # # #
# # # # 67 | # range_search_int 68 | # # # # # # # # # # # # # # # # 69 | add_executable(range_search_int_ls range_search_int.cpp) 70 | target_link_libraries(range_search_int_ls ${Boost_LIBRARIES}) 71 | set_target_properties(range_search_int_ls PROPERTIES COMPILE_DEFINITIONS "ALGO_LS") 72 | 73 | add_executable(range_search_int_dyft range_search_int.cpp) 74 | target_link_libraries(range_search_int_dyft ${Boost_LIBRARIES}) 75 | set_target_properties(range_search_int_dyft PROPERTIES COMPILE_DEFINITIONS "ALGO_DYFT") 76 | 77 | add_executable(range_search_int_hms1dv range_search_int.cpp) 78 | target_link_libraries(range_search_int_hms1dv ${Boost_LIBRARIES}) 79 | set_target_properties(range_search_int_hms1dv PROPERTIES COMPILE_DEFINITIONS "ALGO_HMS1DV") 80 | 81 | add_executable(range_search_int_gv range_search_int.cpp) 82 | target_link_libraries(range_search_int_gv ${Boost_LIBRARIES}) 83 | set_target_properties(range_search_int_gv PROPERTIES COMPILE_DEFINITIONS "ALGO_GV") 84 | -------------------------------------------------------------------------------- /src/build_index_bin.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #ifdef ALGO_DYFT 10 | #include 11 | #endif 12 | #ifdef ALGO_HMS1V 13 | #include 14 | #endif 15 | 16 | using namespace dyft; 17 | 18 | static constexpr size_t ABORT_THRESHOLD_GiB = 100; 19 | 20 | constexpr uint32_t SCALES[] = { // UINT32_MAX ends the 'SCALES[i] < base_size' scan for any base_size 21 | 10, 100, 1'000, 10'000, 100'000, 1'000'000, 10'000'000, 100'000'000, 1'000'000'000, UINT32_MAX, 22 | }; 23 | 24 | cmd_line_parser::parser make_parser(int argc, char** argv) { // declares the command-line options; aborts if parsing fails 25 | cmd_line_parser::parser p(argc, argv); 26 | p.add("base_path", "input file path of database (in 64-bit binary)"); 27 | p.add("result_dir", "output directory path of results", "-o", false); 28 | p.add("radius", "Hamming distance threshold", "-R", false); 29 | p.add("ints", "number of ints for each vector (32|64)", "-N",
false); 30 | #if defined(ALGO_DYFT) 31 | p.add("algorithm", "algorithm (array|art|mart)", "-A", false); 32 | p.add("splitthr", "split threshold (0 means optimal assignment)", "-T", false); 33 | p.add("in_weight", "weight of innode", "-W", false); 34 | p.add("blocks", "number of blocks (0 means reasonable number based on radius)", "-K", false); 35 | #endif 36 | ABORT_IF(!p.parse()); 37 | return p; 38 | } 39 | 40 | template 41 | int main_template(const cmd_line_parser::parser& p) { // N = ints per vector; inserts the database in growing prefixes and records process size and cumulative insertion time for each prefix 42 | const auto base_path = p.get("base_path"); 43 | const auto result_dir = p.get("result_dir", "results"); 44 | const auto radius = p.get("radius", 2); 45 | #if defined(ALGO_DYFT) 46 | const auto algorithm = p.get("algorithm", "mart"); 47 | const auto splitthr = p.get("splitthr", 0); 48 | const auto in_weight = p.get("in_weight", 1.0); 49 | const auto blocks = p.get("blocks", 2); 50 | #endif 51 | 52 | #ifdef ALGO_DYFT 53 | ABORT_IF((algorithm != "array") and (algorithm != "art") and (algorithm != "mart")); 54 | tfm::printfln("algorithm: %s", algorithm); 55 | tfm::printfln("splitthr: %d", splitthr); 56 | tfm::printfln("in_weight: %g", in_weight); 57 | tfm::printfln("blocks: %d", blocks); 58 | STATISTIC_TAG("algorithm", algorithm); 59 | STATISTIC_TAG("splitthr", splitthr); 60 | STATISTIC_TAG("in_weight", in_weight); 61 | STATISTIC_TAG("blocks", blocks); 62 | #endif 63 | #ifdef ALGO_HMS1V 64 | tfm::printfln("algorithm: hms1v"); 65 | STATISTIC_TAG("algorithm", "hms1v"); 66 | #endif 67 | 68 | tfm::printfln("ints: %d", N); 69 | tfm::printfln("radius: %d", radius); 70 | STATISTIC_TAG("ints", N); 71 | STATISTIC_TAG("radius", radius); 72 | 73 | const auto base_codes = load_vcodes_from_bin(base_path); 74 | tfm::printfln("base_path: %s", base_path); 75 | STATISTIC_TAG("base_path", base_path); 76 | 77 | const uint32_t base_size = base_codes->get_size(); 78 | tfm::printfln("base_size: %d", base_size); 79 | STATISTIC_TAG("base_size", base_size); 80 | 81 | std::vector test_scales; 82 | for (uint32_t i
= 0; SCALES[i] < base_size; i++) { 83 | test_scales.push_back(SCALES[i]); 84 | } 85 | test_scales.push_back(base_size); 86 | 87 | std::vector process_sizes(test_scales.size()); 88 | std::vector insertion_times(test_scales.size()); 89 | 90 | uint32_t prev_size = 0; 91 | uint64_t init_process_bytes = get_process_size_in_bytes(); 92 | double insert_time_in_sec = 0.0; 93 | 94 | #ifdef ALGO_DYFT 95 | std::unique_ptr> index; const int num_blocks = (blocks == 0) ? radius / 2 + 1 : blocks; // -K 0 selects radius / 2 + 1 blocks as the help text promises (same rule as build_index_int); <= 1 builds a single index 96 | if (algorithm == "array") { 97 | if (num_blocks <= 1) { 98 | index = std::make_unique>(base_codes.get(), radius, splitthr, in_weight); 99 | } else { 100 | index = std::make_unique>>(base_codes.get(), radius, num_blocks, splitthr, in_weight); 101 | } 102 | } else if (algorithm == "art") { 103 | if (num_blocks <= 1) { 104 | index = std::make_unique>(base_codes.get(), radius, splitthr, in_weight); 105 | } else { 106 | index = std::make_unique>>(base_codes.get(), radius, num_blocks, splitthr, in_weight); 107 | } 108 | } else { // algorithm == "mart" 109 | if (num_blocks <= 1) { 110 | index = std::make_unique>(base_codes.get(), radius, splitthr, in_weight); 111 | } else { 112 | index = std::make_unique>>(base_codes.get(), radius, num_blocks, splitthr, in_weight); 113 | } 114 | } 115 | #endif 116 | #ifdef ALGO_HMS1V 117 | auto index = std::make_unique>(base_codes.get(), radius); 118 | #endif 119 | 120 | for (uint32_t i = 0; i < test_scales.size(); i++) { 121 | const uint32_t test_size = test_scales[i]; 122 | tfm::printfln("# %d codes...", test_size); 123 | 124 | START_TIMER(Insert); 125 | for (uint32_t bi = prev_size; bi < test_size; bi++) { 126 | ABORT_IF_NE(bi, index->append()); 127 | } 128 | STOP_TIMER_V(Insert); 129 | 130 | const size_t process_size = get_process_size_in_bytes(); 131 | process_sizes[i] = process_size - init_process_bytes; 132 | 133 | insert_time_in_sec += GET_TIMER_SEC(Insert); 134 | insertion_times[i] = insert_time_in_sec; 135 | 136 | tfm::reportfln("process size in MiB: %g", process_sizes[i] / (1024.0 * 1024.0)); 137 | tfm::reportfln("insertion time in
sec: %g", insertion_times[i]); 138 | 139 | if (to_GiB(process_size) >= ABORT_THRESHOLD_GiB) { 140 | tfm::warnfln("Abort build because the memory exceeds %d GiB", ABORT_THRESHOLD_GiB); 141 | test_scales.resize(i + 1); 142 | break; 143 | } 144 | 145 | prev_size = test_size; 146 | } 147 | 148 | for (uint32_t i = 0; i < test_scales.size(); i++) { 149 | const uint32_t test_size = test_scales[i]; 150 | const uint64_t process_size = process_sizes[i]; 151 | const double insertion_time = insertion_times[i]; 152 | STATISTIC_APPEND("insertion", {{"num_codes", test_size}, // 153 | {"process_bytes", process_size}, // 154 | {"insertion_time_in_sec", insertion_time}}); 155 | } 156 | 157 | #ifdef ALGO_DYFT 158 | const auto result_path = tfm::format("%s/build_index_bin_%s-%s-%dN-%dR-%dT-%gW-%dK.json", // 159 | result_dir, algorithm, normalize_filepath(base_path), // 160 | N, radius, splitthr, in_weight * 100, blocks); 161 | #endif 162 | #ifdef ALGO_HMS1V 163 | const auto result_path = tfm::format("%s/build_index_bin_hms1v-%s-%dN-%dR.json", // 164 | result_dir, normalize_filepath(base_path), N, radius); 165 | #endif 166 | 167 | make_directory(result_dir); 168 | STATISTIC_SAVE(result_path); 169 | tfm::printfln("wrote %s", result_path); 170 | 171 | return 0; 172 | } 173 | 174 | int main(int argc, char** argv) { // dispatches on -N (32 or 64 ints per vector); returns 1 for any other value 175 | #ifndef NDEBUG 176 | tfm::warnfln("The code is running in debug mode."); 177 | #endif 178 | 179 | auto p = make_parser(argc, argv); 180 | const auto ints = p.get("ints", 64); 181 | 182 | switch (ints) { 183 | case 32: 184 | return main_template<32>(p); 185 | case 64: 186 | return main_template<64>(p); 187 | default: 188 | tfm::warnfln("unsupported ints: %d (expected 32 or 64)", ints); break; 189 | } 190 | 191 | return 1; 192 | } -------------------------------------------------------------------------------- /src/build_index_int.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #ifdef ALGO_DYFT 10 | #include 11
| #endif 12 | #ifdef ALGO_HMS1DV 13 | #include 14 | #endif 15 | #ifdef ALGO_GV 16 | #include 17 | #endif 18 | 19 | using namespace dyft; 20 | 21 | static constexpr size_t ABORT_THRESHOLD_GiB = 100; 22 | 23 | constexpr uint32_t SCALES[] = { 24 | 10, 100, 1'000, 10'000, 100'000, 1'000'000, 10'000'000, 100'000'000, 1'000'000'000, UINT32_MAX, 25 | }; 26 | 27 | cmd_line_parser::parser make_parser(int argc, char** argv) { 28 | cmd_line_parser::parser p(argc, argv); 29 | p.add("base_path", "input file path of database (in 64-bit binary)"); 30 | p.add("result_dir", "output directory path of results", "-o", false); 31 | p.add("radius", "Hamming distance threshold", "-R", false); 32 | p.add("ints", "number of ints for each vector (32|64)", "-N", false); 33 | p.add("bits", "number of bits for each integer [1,8]", "-B", false); 34 | #if defined(ALGO_DYFT) 35 | p.add("algorithm", "algorithm (array|art|mart)", "-A", false); 36 | p.add("splitthr", "split threshold (0 means optimal assignment)", "-T", false); 37 | p.add("in_weight", "weigth of innode", "-W", false); 38 | p.add("blocks", "number of blocks (0 means reasonable number based on radius)", "-K", false); 39 | #endif 40 | ABORT_IF(!p.parse()); 41 | return p; 42 | } 43 | 44 | template 45 | int main_template(const cmd_line_parser::parser& p) { 46 | const auto base_path = p.get("base_path"); 47 | const auto result_dir = p.get("result_dir", "results"); 48 | const auto radius = p.get("radius", 2); 49 | const auto bits = p.get("bits", 4); 50 | #if defined(ALGO_DYFT) 51 | const auto algorithm = p.get("algorithm", "mart"); 52 | const auto splitthr = p.get("splitthr", 0); 53 | const auto in_weight = p.get("in_weight", 1.0); 54 | const auto blocks = p.get("blocks", 2); 55 | #endif 56 | 57 | #ifdef ALGO_DYFT 58 | ABORT_IF((algorithm != "array") and (algorithm != "art") and (algorithm != "mart")); 59 | tfm::printfln("algorithm: %s", algorithm); 60 | tfm::printfln("splitthr: %d", splitthr); 61 | tfm::printfln("in_weight: %g", 
in_weight); 62 | tfm::printfln("blocks: %d", blocks); 63 | STATISTIC_TAG("algorithm", algorithm); 64 | STATISTIC_TAG("splitthr", splitthr); 65 | STATISTIC_TAG("in_weight", in_weight); 66 | STATISTIC_TAG("blocks", blocks); 67 | #endif 68 | #ifdef ALGO_HMS1DV 69 | tfm::printfln("algorithm: hms1dv"); 70 | STATISTIC_TAG("algorithm", "hms1dv"); 71 | #endif 72 | #ifdef ALGO_GV 73 | tfm::printfln("algorithm: gv"); 74 | STATISTIC_TAG("algorithm", "gv"); 75 | #endif 76 | 77 | tfm::printfln("ints: %d", N); 78 | tfm::printfln("bits: %d", bits); 79 | tfm::printfln("radius: %d", radius); 80 | STATISTIC_TAG("ints", N); 81 | STATISTIC_TAG("bits", bits); 82 | STATISTIC_TAG("radius", radius); 83 | 84 | const auto base_codes = load_vcodes_from_bvecs(base_path, bits); 85 | tfm::printfln("base_path: %s", base_path); 86 | STATISTIC_TAG("base_path", base_path); 87 | 88 | const uint32_t base_size = base_codes->get_size(); 89 | tfm::printfln("base_size: %d", base_size); 90 | STATISTIC_TAG("base_size", base_size); 91 | 92 | std::vector test_scales; 93 | for (uint32_t i = 0; SCALES[i] < base_size; i++) { 94 | test_scales.push_back(SCALES[i]); 95 | } 96 | test_scales.push_back(base_size); 97 | 98 | std::vector process_sizes(test_scales.size()); 99 | std::vector insertion_times(test_scales.size()); 100 | 101 | uint32_t prev_size = 0; 102 | uint64_t init_process_bytes = get_process_size_in_bytes(); 103 | double insert_time_in_sec = 0.0; 104 | 105 | #ifdef ALGO_DYFT 106 | std::unique_ptr> index; 107 | 108 | if (algorithm == "array") { 109 | if (blocks == 1) { 110 | index = std::make_unique>(base_codes.get(), radius, splitthr, in_weight); 111 | } else { 112 | index = std::make_unique>>( 113 | base_codes.get(), radius, blocks == 0 ? 
radius / 2 + 1 : blocks, splitthr, in_weight); 114 | } 115 | } else if (algorithm == "art") { 116 | if (blocks == 1) { 117 | index = std::make_unique>(base_codes.get(), radius, splitthr, in_weight); 118 | } else { 119 | index = std::make_unique>>( 120 | base_codes.get(), radius, blocks == 0 ? radius / 2 + 1 : blocks, splitthr, in_weight); 121 | } 122 | } else { // algorithm == "mart" 123 | if (blocks == 1) { 124 | index = std::make_unique>(base_codes.get(), radius, splitthr, in_weight); 125 | } else { 126 | index = std::make_unique>>( 127 | base_codes.get(), radius, blocks == 0 ? radius / 2 + 1 : blocks, splitthr, in_weight); 128 | } 129 | } 130 | #endif 131 | #if defined(ALGO_HMS1DV) 132 | auto index = std::make_unique>(base_codes.get(), radius); 133 | #endif 134 | #if defined(ALGO_GV) 135 | auto index = std::make_unique>(base_codes.get(), radius); 136 | #endif 137 | 138 | for (uint32_t i = 0; i < test_scales.size(); i++) { 139 | const uint32_t test_size = test_scales[i]; 140 | tfm::printfln("# %d codes...", test_size); 141 | 142 | START_TIMER(Insert); 143 | for (uint32_t bi = prev_size; bi < test_size; bi++) { 144 | ABORT_IF_NE(bi, index->append()); 145 | } 146 | STOP_TIMER_V(Insert); 147 | 148 | const size_t process_size = get_process_size_in_bytes(); 149 | process_sizes[i] = process_size - init_process_bytes; 150 | 151 | insert_time_in_sec += GET_TIMER_SEC(Insert); 152 | insertion_times[i] = insert_time_in_sec; 153 | 154 | tfm::reportfln("process size in MiB: %g", process_sizes[i] / (1024.0 * 1024.0)); 155 | tfm::reportfln("insertion time in sec: %g", insertion_times[i]); 156 | 157 | if (to_GiB(process_size) >= ABORT_THRESHOLD_GiB) { 158 | tfm::warnfln("Abort build becasue the memory exceeds %d GiB", ABORT_THRESHOLD_GiB); 159 | test_scales.resize(i + 1); 160 | break; 161 | } 162 | 163 | prev_size = test_size; 164 | } 165 | 166 | for (uint32_t i = 0; i < test_scales.size(); i++) { 167 | const uint32_t test_size = test_scales[i]; 168 | const uint64_t process_size 
= process_sizes[i]; 169 | const double insertion_time = insertion_times[i]; 170 | STATISTIC_APPEND("insertion", {{"num_codes", test_size}, // 171 | {"process_bytes", process_size}, // 172 | {"insertion_time_in_sec", insertion_time}}); 173 | } 174 | 175 | #ifdef ALGO_DYFT 176 | const auto result_path = tfm::format("%s/build_index_int_%s-%s-%dN-%dB-%dR-%dT-%gW-%dK.json", // 177 | result_dir, algorithm, normalize_filepath(base_path), // 178 | N, bits, radius, splitthr, in_weight * 100, blocks); 179 | #endif 180 | #ifdef ALGO_HMS1DV 181 | const auto result_path = tfm::format("%s/build_index_int_hms1dv-%s-%dN-%dB-%dR.json", // 182 | result_dir, normalize_filepath(base_path), N, bits, radius); 183 | #endif 184 | #ifdef ALGO_GV 185 | const auto result_path = tfm::format("%s/build_index_int_gv-%s-%dN-%dB-%dR.json", // 186 | result_dir, normalize_filepath(base_path), N, bits, radius); 187 | #endif 188 | 189 | make_directory(result_dir); 190 | STATISTIC_SAVE(result_path); 191 | tfm::printfln("wrote %s", result_path); 192 | 193 | return 0; 194 | } 195 | 196 | int main(int argc, char** argv) { 197 | #ifndef NDEBUG 198 | tfm::warnfln("The code is running in debug mode."); 199 | #endif 200 | 201 | auto p = make_parser(argc, argv); 202 | const auto ints = p.get("ints", 64); 203 | 204 | switch (ints) { 205 | case 32: 206 | return main_template<32>(p); 207 | case 64: 208 | return main_template<64>(p); 209 | default: 210 | break; 211 | } 212 | 213 | return 1; 214 | } -------------------------------------------------------------------------------- /src/bvecs_to_bin.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | using namespace dyft; 8 | 9 | std::vector sample_codes(const std::vector& codes, uint32_t num, uint64_t seed) { 10 | splitmix64 engine(seed); 11 | std::vector sampled(num); 12 | for (uint32_t i = 0; i < num; ++i) { 13 | sampled[i] = codes[engine.next() % codes.size()]; 14 | } 
15 | return sampled; 16 | } 17 | 18 | int main(int argc, char** argv) { 19 | #ifndef NDEBUG 20 | tfm::warnfln("The code is running in debug mode."); 21 | #endif 22 | 23 | cmd_line_parser::parser p(argc, argv); 24 | p.add("input_path", "input file path of codes (in bvecs)"); 25 | p.add("output_path", "output file path of sampled codes (in 64-binary)"); 26 | ABORT_IF(!p.parse()); 27 | 28 | auto input_path = p.get("input_path"); 29 | auto output_path = p.get("output_path"); 30 | 31 | const auto int_codes = load_vcodes_from_bvecs<64>(input_path, 1); 32 | 33 | auto ofs = make_ofstream(output_path); 34 | for (uint32_t i = 0; i < int_codes->get_size(); i++) { 35 | const uint64_t* code = int_codes->access(i); 36 | ofs.write(reinterpret_cast(code), sizeof(uint64_t)); 37 | } 38 | 39 | return 0; 40 | } -------------------------------------------------------------------------------- /src/gen_uniform.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace dyft; 10 | 11 | static constexpr uint32_t DIM = 64; 12 | 13 | int main(int argc, char** argv) { 14 | #ifndef NDEBUG 15 | tfm::warnfln("The code is running in debug mode."); 16 | #endif 17 | 18 | cmd_line_parser::parser p(argc, argv); 19 | p.add("output_path", "output file path of random codes (in bvecs)"); 20 | p.add("num", "number of codes"); 21 | p.add("seed", "seed", "-s", false); 22 | ABORT_IF(!p.parse()); 23 | 24 | auto output_path = p.get("output_path"); 25 | auto num = p.get("num"); 26 | auto seed = p.get("seed", 114514); 27 | 28 | std::default_random_engine engine(seed); 29 | std::uniform_int_distribution<> dist(0, 255); 30 | 31 | auto ofs = make_ofstream(output_path); 32 | 33 | for (size_t i = 0; i < num; i++) { 34 | uint8_t code[DIM]; 35 | for (size_t j = 0; j < DIM; j++) { 36 | code[j] = static_cast(dist(engine)); 37 | } 38 | ofs.write(reinterpret_cast(&DIM), sizeof(uint32_t)); 39 | 
ofs.write(reinterpret_cast(code), sizeof(uint8_t) * DIM); 40 | } 41 | 42 | return 0; 43 | } -------------------------------------------------------------------------------- /src/precompute.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | static uint32_t HDT[256][256]; 4 | 5 | void compute_hdt(uint32_t bits) { 6 | const uint32_t N = 8 / bits; 7 | const uint32_t U = 1 << (N * bits); 8 | const uint32_t X = (1 << bits) - 1; 9 | 10 | for (uint32_t i = 0; i < 256; i++) { 11 | for (uint32_t j = 0; j < 256; j++) { 12 | HDT[i][j] = 0; 13 | } 14 | } 15 | 16 | for (uint32_t i = 0; i < U; i++) { 17 | for (uint32_t j = 0; j < U; j++) { 18 | uint32_t d = 0; 19 | for (uint32_t k = 0; k < N; k++) { 20 | const uint32_t a = (i >> (k * bits)) & X; 21 | const uint32_t b = (j >> (k * bits)) & X; 22 | d += (a != b) ? 1 : 0; 23 | } 24 | HDT[i][j] = d; 25 | } 26 | } 27 | } 28 | 29 | static uint8_t LUT[256][256]; 30 | static uint32_t DB[8][9]; 31 | 32 | void compute_lut(uint32_t bits) { 33 | const uint32_t N = 8 / bits; 34 | const uint32_t U = 1 << (N * bits); 35 | const uint32_t X = (1 << bits) - 1; 36 | 37 | for (uint32_t k = 0; k < 9; k++) { 38 | DB[bits - 1][k] = 0; 39 | } 40 | for (uint32_t j = 0; j < U; j++) { 41 | uint32_t d = 0; 42 | for (uint32_t k = 0; k < N; k++) { 43 | const uint32_t b = (j >> (k * bits)) & X; 44 | d += (0 != b) ? 1 : 0; 45 | } 46 | DB[bits - 1][d] += 1; 47 | } 48 | 49 | for (uint32_t i = 0; i < U; i++) { 50 | std::pair row[256]; // c, d 51 | for (uint32_t j = 0; j < U; j++) { 52 | uint32_t d = 0; 53 | for (uint32_t k = 0; k < N; k++) { 54 | const uint32_t a = (i >> (k * bits)) & X; 55 | const uint32_t b = (j >> (k * bits)) & X; 56 | d += (a != b) ? 
1 : 0; 57 | } 58 | row[j] = std::pair{j, d}; 59 | } 60 | std::sort(row, row + U, [](auto x, auto y) { 61 | if (x.second != y.second) { 62 | return x.second < y.second; 63 | } 64 | return x.first < y.first; 65 | }); 66 | for (uint32_t j = 0; j < U; j++) { 67 | LUT[i][j] = row[j].first; 68 | } 69 | for (uint32_t j = U; j < 256; j++) { 70 | LUT[i][j] = 0; 71 | } 72 | } 73 | } 74 | 75 | int main(int argc, char** argv) { 76 | #ifndef NDEBUG 77 | tfm::warnfln("The code is running in debug mode."); 78 | #endif 79 | 80 | tfm::printfln("static constexpr uint8_t HD_TABLE[8][256][256] = {"); 81 | for (uint32_t b = 1; b <= 8; b++) { 82 | compute_hdt(b); 83 | 84 | tfm::printfln("\t{ // b=%d", b); 85 | for (uint32_t i = 0; i < 256; i++) { 86 | tfm::printf("\t\t{"); 87 | for (uint32_t j = 0; j < 256; j++) { 88 | tfm::printf("%d, ", HDT[i][j]); 89 | } 90 | tfm::printfln("},"); 91 | } 92 | tfm::printfln("\t},"); 93 | } 94 | tfm::printfln("};"); 95 | 96 | tfm::printfln("static constexpr uint8_t LU_TABLE[8][256][256] = {"); 97 | for (uint32_t b = 1; b <= 8; b++) { 98 | compute_lut(b); 99 | 100 | tfm::printfln("\t{ // b=%d", b); 101 | for (uint32_t i = 0; i < 256; i++) { 102 | tfm::printf("\t\t{"); 103 | for (uint32_t j = 0; j < 256; j++) { 104 | tfm::printf("%d, ", LUT[i][j]); 105 | } 106 | tfm::printfln("},"); 107 | } 108 | tfm::printfln("\t},"); 109 | } 110 | tfm::printfln("};"); 111 | 112 | tfm::printfln("static constexpr int BP_TABLE[8][10] = {"); 113 | for (uint32_t b = 1; b <= 8; b++) { 114 | int n = 0; 115 | tfm::printf("\t{"); 116 | for (uint32_t k = 0; k < 9; k++) { 117 | tfm::printf("%d, ", n); 118 | n += DB[b - 1][k]; 119 | } 120 | tfm::printfln("%d}, // b=%d", n, b); 121 | } 122 | tfm::printfln("};"); 123 | 124 | return 0; 125 | } -------------------------------------------------------------------------------- /src/range_search_bin.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 
6 | #include 7 | #include 8 | 9 | #ifdef ALGO_DYFT 10 | #include 11 | #endif 12 | #ifdef ALGO_HMS1V 13 | #include 14 | #endif 15 | 16 | using namespace dyft; 17 | 18 | static constexpr double ABORT_THRESHOLD_MS = 100.0; 19 | 20 | constexpr uint32_t SCALES[] = { 21 | 10, 100, 1'000, 10'000, 100'000, 1'000'000, 10'000'000, 100'000'000, 1'000'000'000, UINT32_MAX, 22 | }; 23 | 24 | cmd_line_parser::parser make_parser(int argc, char** argv) { 25 | cmd_line_parser::parser p(argc, argv); 26 | p.add("base_path", "input file path of database (in 64-bit binary)"); 27 | p.add("query_path", "input file path of queries (in 64-bit binary)"); 28 | p.add("result_dir", "output directory path of results", "-o", false); 29 | p.add("radius", "Hamming distance threshold", "-R", false); 30 | p.add("ints", "number of ints for each vector (32|64)", "-N", false); 31 | #if defined(ALGO_DYFT) 32 | p.add("algorithm", "algorithm (array|art|mart)", "-A", false); 33 | p.add("splitthr", "split threshold (0 means optimal assignment)", "-T", false); 34 | p.add("in_weight", "weigth of innode", "-W", false); 35 | p.add("blocks", "number of blocks (0 means reasonable number based on radius)", "-K", false); 36 | #endif 37 | ABORT_IF(!p.parse()); 38 | return p; 39 | } 40 | 41 | template 42 | int main_template(const cmd_line_parser::parser& p) { 43 | using vint_type = typename vcode_traits::vint_type; 44 | 45 | const auto base_path = p.get("base_path"); 46 | const auto query_path = p.get("query_path"); 47 | const auto result_dir = p.get("result_dir", "results"); 48 | const auto radius = p.get("radius", 2); 49 | #if defined(ALGO_DYFT) 50 | const auto algorithm = p.get("algorithm", "mart"); 51 | const auto splitthr = p.get("splitthr", 0); 52 | const auto in_weight = p.get("in_weight", 1.0); 53 | const auto blocks = p.get("blocks", 2); 54 | #endif 55 | 56 | #ifdef ALGO_LS 57 | tfm::printfln("algorithm: ls"); 58 | STATISTIC_TAG("algorithm", "ls"); 59 | #endif 60 | #ifdef ALGO_DYFT 61 | ABORT_IF((algorithm != 
"array") and (algorithm != "art") and (algorithm != "mart")); 62 | tfm::printfln("algorithm: %s", algorithm); 63 | tfm::printfln("splitthr: %d", splitthr); 64 | tfm::printfln("in_weight: %g", in_weight); 65 | tfm::printfln("blocks: %d", blocks); 66 | STATISTIC_TAG("algorithm", algorithm); 67 | STATISTIC_TAG("splitthr", splitthr); 68 | STATISTIC_TAG("in_weight", in_weight); 69 | STATISTIC_TAG("blocks", blocks); 70 | #endif 71 | #ifdef ALGO_HMS1V 72 | tfm::printfln("algorithm: hms1v"); 73 | STATISTIC_TAG("algorithm", "hms1v"); 74 | #endif 75 | 76 | tfm::printfln("ints: %d", N); 77 | tfm::printfln("radius: %d", radius); 78 | STATISTIC_TAG("ints", N); 79 | STATISTIC_TAG("radius", radius); 80 | 81 | const auto base_codes = load_vcodes_from_bin(base_path); 82 | const auto query_codes = load_vcodes_from_bin(query_path); 83 | tfm::printfln("base_path: %s", base_path); 84 | tfm::printfln("query_path: %s", query_path); 85 | STATISTIC_TAG("base_path", base_path); 86 | STATISTIC_TAG("query_path", query_path); 87 | 88 | const uint32_t base_size = base_codes->get_size(); 89 | const uint32_t query_size = query_codes->get_size(); 90 | tfm::printfln("base_size: %d", base_size); 91 | tfm::printfln("query_size: %d", query_size); 92 | STATISTIC_TAG("base_size", base_size); 93 | STATISTIC_TAG("query_size", query_size); 94 | 95 | std::vector test_scales; 96 | for (uint32_t i = 0; SCALES[i] < base_size; i++) { 97 | test_scales.push_back(SCALES[i]); 98 | } 99 | test_scales.push_back(base_size); 100 | 101 | uint32_t prev_size = 0; 102 | std::vector counts(query_size); 103 | std::vector verify_counts(query_size); 104 | 105 | #ifdef ALGO_LS 106 | double search_time_in_ms = 0.0; 107 | 108 | for (uint32_t test_size : test_scales) { 109 | tfm::printfln("# %d codes...", test_size); 110 | 111 | START_TIMER(Search); 112 | for (uint32_t qi = 0; qi < query_size; qi++) { 113 | const vint_type* qvcode = query_codes->access(qi); 114 | for (uint32_t bi = prev_size; bi < test_size; bi++) { 115 | const 
vint_type* bvcode = base_codes->access(bi); 116 | const int hamdist = vcode_tools::get_hamdist(qvcode, bvcode, 1, radius); 117 | if (hamdist <= radius) { 118 | counts[qi] += 1; 119 | } 120 | } 121 | } 122 | STOP_TIMER_V(Search); 123 | 124 | search_time_in_ms += GET_TIMER_MILLISEC(Search); 125 | const double ave_search_time_in_ms = search_time_in_ms / query_size; 126 | 127 | const double ave_num_results = get_average(counts); 128 | tfm::reportfln("average number of results: %g", ave_num_results); 129 | tfm::reportfln("average search time in ms: %g", ave_search_time_in_ms); 130 | 131 | STATISTIC_APPEND("search", {{"num_codes", test_size}, // 132 | {"ave_search_time_in_ms", ave_search_time_in_ms}, // 133 | {"ave_num_results", ave_num_results}}); 134 | 135 | if (ave_search_time_in_ms >= ABORT_THRESHOLD_MS) { 136 | tfm::warnfln("Abort search becasue the time exceeds %g ms", ABORT_THRESHOLD_MS); 137 | break; 138 | } 139 | 140 | prev_size = test_size; 141 | } 142 | 143 | const auto result_path = tfm::format("%s/range_search_bin_ls-%s-%s-%dN-%dR.json", // 144 | result_dir, normalize_filepath(base_path), normalize_filepath(query_path), // 145 | N, radius); 146 | #endif 147 | 148 | #ifdef ALGO_DYFT 149 | std::unique_ptr> index; 150 | 151 | if (algorithm == "array") { 152 | if (blocks <= 1) { 153 | index = std::make_unique>(base_codes.get(), radius, splitthr, in_weight); 154 | } else { 155 | index = std::make_unique>>(base_codes.get(), radius, blocks, splitthr, in_weight); 156 | } 157 | } else if (algorithm == "art") { 158 | if (blocks <= 1) { 159 | index = std::make_unique>(base_codes.get(), radius, splitthr, in_weight); 160 | } else { 161 | index = std::make_unique>>(base_codes.get(), radius, blocks, splitthr, in_weight); 162 | } 163 | } else { // algorithm == "mart" 164 | if (blocks <= 1) { 165 | index = std::make_unique>(base_codes.get(), radius, splitthr, in_weight); 166 | } else { 167 | index = std::make_unique>>(base_codes.get(), radius, blocks, splitthr, in_weight); 
168 | } 169 | } 170 | 171 | double insert_time_in_ms = 0.0; 172 | 173 | for (uint32_t test_size : test_scales) { 174 | tfm::printfln("# %d codes...", test_size); 175 | 176 | START_TIMER(Insert); 177 | for (uint32_t bi = prev_size; bi < test_size; bi++) { 178 | ABORT_IF_NE(bi, index->append()); 179 | } 180 | STOP_TIMER_V(Insert); 181 | 182 | insert_time_in_ms += GET_TIMER_MILLISEC(Insert); 183 | const double ave_insert_time_in_ms = insert_time_in_ms / test_size; 184 | 185 | START_TIMER(Search); 186 | for (uint32_t qi = 0; qi < query_size; qi++) { 187 | const vint_type* qvcode = query_codes->access(qi); 188 | 189 | uint32_t count = 0; 190 | uint32_t verify_count = 0; 191 | 192 | index->range_search(qvcode, [&](uint32_t bi) { 193 | const vint_type* bvcode = base_codes->access(bi); 194 | const int hamdist = vcode_tools::get_hamdist(qvcode, bvcode, 1, radius); 195 | if (hamdist <= radius) { 196 | count += 1; 197 | } 198 | verify_count += 1; 199 | }); 200 | counts[qi] = count; 201 | verify_counts[qi] = verify_count; 202 | } 203 | STOP_TIMER_V(Search); 204 | 205 | std::string selected = index->selects_ls() ? 
"LS" : "Trie"; 206 | tfm::reportfln("selected search: %s", selected); 207 | 208 | const double ave_search_time_in_ms = GET_TIMER_MILLISEC(Search) / query_size; 209 | tfm::reportfln("average insert time in ms: %g", ave_insert_time_in_ms); 210 | tfm::reportfln("average search time in ms: %g", ave_search_time_in_ms); 211 | 212 | const double ave_num_results = get_average(counts); 213 | const double ave_num_cadidates = get_average(verify_counts); 214 | tfm::reportfln("average number of results: %g", ave_num_results); 215 | tfm::reportfln("average number of candidates: %g", ave_num_cadidates); 216 | 217 | tfm::reportfln("num leaves: %d", index->get_leaves()); 218 | tfm::reportfln("num splits: %d", index->get_split_count()); 219 | 220 | STATISTIC_APPEND("search", {{"num_codes", test_size}, // 221 | {"num_leaves", index->get_leaves()}, // 222 | {"num_splits", index->get_split_count()}, // 223 | {"selected_search", selected}, // 224 | {"ave_insert_time_in_ms", ave_insert_time_in_ms}, // 225 | {"ave_search_time_in_ms", ave_search_time_in_ms}, // 226 | {"ave_num_results", ave_num_results}, // 227 | {"ave_num_cadidates", ave_num_cadidates}}); 228 | 229 | if (ave_search_time_in_ms >= ABORT_THRESHOLD_MS) { 230 | tfm::warnfln("Abort search becasue the time exceeds %g ms", ABORT_THRESHOLD_MS); 231 | break; 232 | } 233 | 234 | prev_size = test_size; 235 | } 236 | 237 | const auto result_path = tfm::format("%s/range_search_bin_%s-%s-%s-%dN-%dR-%dT-%gW-%dK.json", // 238 | result_dir, algorithm, // 239 | normalize_filepath(base_path), normalize_filepath(query_path), // 240 | N, radius, splitthr, in_weight * 100, blocks); 241 | #endif 242 | 243 | #ifdef ALGO_HMS1V 244 | double insert_time_in_ms = 0.0; 245 | hms1v_index index(base_codes.get(), radius); 246 | 247 | for (uint32_t test_size : test_scales) { 248 | tfm::printfln("# %d codes...", test_size); 249 | 250 | START_TIMER(Insert); 251 | for (uint32_t bi = prev_size; bi < test_size; bi++) { 252 | ABORT_IF_NE(bi, index.append()); 253 
| } 254 | STOP_TIMER_V(Insert); 255 | 256 | insert_time_in_ms += GET_TIMER_MILLISEC(Insert); 257 | const double ave_insert_time_in_ms = insert_time_in_ms / test_size; 258 | 259 | START_TIMER(Search); 260 | for (uint32_t qi = 0; qi < query_size; qi++) { 261 | const vint_type* qvcode = query_codes->access(qi); 262 | 263 | uint32_t count = 0; 264 | uint32_t verify_count = 0; 265 | 266 | index.range_search(qvcode, [&](uint32_t bi) { 267 | const vint_type* bvcode = base_codes->access(bi); 268 | const int hamdist = vcode_tools::get_hamdist(qvcode, bvcode, 1, radius); 269 | if (hamdist <= radius) { 270 | count += 1; 271 | } 272 | verify_count += 1; 273 | }); 274 | counts[qi] = count; 275 | verify_counts[qi] = verify_count; 276 | } 277 | STOP_TIMER_V(Search); 278 | 279 | const double ave_search_time_in_ms = GET_TIMER_MILLISEC(Search) / query_size; 280 | tfm::reportfln("average insert time in ms: %g", ave_insert_time_in_ms); 281 | tfm::reportfln("average search time in ms: %g", ave_search_time_in_ms); 282 | 283 | const double ave_num_results = get_average(counts); 284 | const double ave_num_cadidates = get_average(verify_counts); 285 | tfm::reportfln("average number of results: %g", ave_num_results); 286 | tfm::reportfln("average number of candidates: %g", ave_num_cadidates); 287 | 288 | STATISTIC_APPEND("search", {{"num_codes", test_size}, // 289 | {"ave_insert_time_in_ms", ave_insert_time_in_ms}, // 290 | {"ave_search_time_in_ms", ave_search_time_in_ms}, // 291 | {"ave_num_results", ave_num_results}, // 292 | {"ave_num_cadidates", ave_num_cadidates}}); 293 | 294 | if (ave_search_time_in_ms >= ABORT_THRESHOLD_MS) { 295 | tfm::warnfln("Abort search becasue the time exceeds %g ms", ABORT_THRESHOLD_MS); 296 | break; 297 | } 298 | 299 | prev_size = test_size; 300 | } 301 | 302 | const auto result_path = tfm::format("%s/range_search_bin_hms1v-%s-%s-%dN-%dR.json", // 303 | result_dir, normalize_filepath(base_path), normalize_filepath(query_path), // 304 | N, radius); 305 | 
#endif 306 | 307 | make_directory(result_dir); 308 | STATISTIC_SAVE(result_path); 309 | tfm::printfln("wrote %s", result_path); 310 | 311 | return 0; 312 | } 313 | 314 | int main(int argc, char** argv) { 315 | #ifndef NDEBUG 316 | tfm::warnfln("The code is running in debug mode."); 317 | #endif 318 | 319 | auto p = make_parser(argc, argv); 320 | const auto ints = p.get("ints", 64); 321 | 322 | switch (ints) { 323 | case 32: 324 | return main_template<32>(p); 325 | case 64: 326 | return main_template<64>(p); 327 | default: 328 | break; 329 | } 330 | 331 | return 1; 332 | } -------------------------------------------------------------------------------- /src/range_search_int.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #ifdef ALGO_DYFT 10 | #include 11 | #endif 12 | #ifdef ALGO_HMS1DV 13 | #include 14 | #endif 15 | #ifdef ALGO_GV 16 | #include 17 | #endif 18 | 19 | using namespace dyft; 20 | 21 | static constexpr double ABORT_THRESHOLD_MS = 100.0; 22 | 23 | constexpr uint32_t SCALES[] = { 24 | 10, 100, 1'000, 10'000, 100'000, 1'000'000, 10'000'000, 100'000'000, 1'000'000'000, UINT32_MAX, 25 | }; 26 | 27 | cmd_line_parser::parser make_parser(int argc, char** argv) { 28 | cmd_line_parser::parser p(argc, argv); 29 | p.add("base_path", "input file path of database (in 64-bit binary)"); 30 | p.add("query_path", "input file path of queries (in 64-bit binary)"); 31 | p.add("result_dir", "output directory path of results", "-o", false); 32 | p.add("radius", "Hamming distance threshold", "-R", false); 33 | p.add("ints", "number of ints for each vector (32|64)", "-N", false); 34 | p.add("bits", "number of bits for each integer [1,8]", "-B", false); 35 | #if defined(ALGO_DYFT) 36 | p.add("algorithm", "algorithm (array|art|mart)", "-A", false); 37 | p.add("splitthr", "split threshold (0 means optimal assignment)", "-T", false); 38 | 
p.add("in_weight", "weigth of innode", "-W", false); 39 | p.add("blocks", "number of blocks (0 means reasonable number based on radius)", "-K", false); 40 | #endif 41 | ABORT_IF(!p.parse()); 42 | return p; 43 | } 44 | 45 | template 46 | int main_template(const cmd_line_parser::parser& p) { 47 | using vint_type = typename vcode_traits::vint_type; 48 | 49 | const auto base_path = p.get("base_path"); 50 | const auto query_path = p.get("query_path"); 51 | const auto result_dir = p.get("result_dir", "results"); 52 | const auto radius = p.get("radius", 2); 53 | const auto bits = p.get("bits", 4); 54 | #if defined(ALGO_DYFT) 55 | const auto algorithm = p.get("algorithm", "mart"); 56 | const auto splitthr = p.get("splitthr", 0); 57 | const auto in_weight = p.get("in_weight", 1.0); 58 | const auto blocks = p.get("blocks", 2); 59 | #endif 60 | 61 | #ifdef ALGO_LS 62 | tfm::printfln("algorithm: ls"); 63 | STATISTIC_TAG("algorithm", "ls"); 64 | #endif 65 | #ifdef ALGO_DYFT 66 | ABORT_IF((algorithm != "array") and (algorithm != "art") and (algorithm != "mart")); 67 | tfm::printfln("algorithm: %s", algorithm); 68 | tfm::printfln("splitthr: %d", splitthr); 69 | tfm::printfln("in_weight: %g", in_weight); 70 | tfm::printfln("blocks: %d", blocks); 71 | STATISTIC_TAG("algorithm", algorithm); 72 | STATISTIC_TAG("splitthr", splitthr); 73 | STATISTIC_TAG("in_weight", in_weight); 74 | STATISTIC_TAG("blocks", blocks); 75 | #endif 76 | #ifdef ALGO_HMS1DV 77 | tfm::printfln("algorithm: hms1dv"); 78 | STATISTIC_TAG("algorithm", "hms1dv"); 79 | #endif 80 | #ifdef ALGO_GV 81 | tfm::printfln("algorithm: gv"); 82 | STATISTIC_TAG("algorithm", "gv"); 83 | #endif 84 | 85 | tfm::printfln("ints: %d", N); 86 | tfm::printfln("bits: %d", bits); 87 | tfm::printfln("radius: %d", radius); 88 | STATISTIC_TAG("ints", N); 89 | STATISTIC_TAG("bits", bits); 90 | STATISTIC_TAG("radius", radius); 91 | 92 | const auto base_codes = load_vcodes_from_bvecs(base_path, bits); 93 | const auto query_codes = 
load_vcodes_from_bvecs(query_path, bits); 94 | tfm::printfln("base_path: %s", base_path); 95 | tfm::printfln("query_path: %s", query_path); 96 | STATISTIC_TAG("base_path", base_path); 97 | STATISTIC_TAG("query_path", query_path); 98 | 99 | const uint32_t base_size = base_codes->get_size(); 100 | const uint32_t query_size = query_codes->get_size(); 101 | tfm::printfln("base_size: %d", base_size); 102 | tfm::printfln("query_size: %d", query_size); 103 | STATISTIC_TAG("base_size", base_size); 104 | STATISTIC_TAG("query_size", query_size); 105 | 106 | std::vector test_scales; 107 | for (uint32_t i = 0; SCALES[i] < base_size; i++) { 108 | test_scales.push_back(SCALES[i]); 109 | } 110 | test_scales.push_back(base_size); 111 | 112 | uint32_t prev_size = 0; 113 | std::vector counts(query_size); 114 | std::vector verify_counts(query_size); 115 | 116 | #ifdef ALGO_LS 117 | double search_time_in_ms = 0.0; 118 | 119 | for (uint32_t test_size : test_scales) { 120 | tfm::printfln("# %d codes...", test_size); 121 | 122 | START_TIMER(Search); 123 | for (uint32_t qi = 0; qi < query_size; qi++) { 124 | const vint_type* qvcode = query_codes->access(qi); 125 | for (uint32_t bi = prev_size; bi < test_size; bi++) { 126 | const vint_type* bvcode = base_codes->access(bi); 127 | const int hamdist = vcode_tools::get_hamdist(qvcode, bvcode, bits, radius); 128 | if (hamdist <= radius) { 129 | counts[qi] += 1; 130 | } 131 | } 132 | } 133 | STOP_TIMER_V(Search); 134 | 135 | search_time_in_ms += GET_TIMER_MILLISEC(Search); 136 | const double ave_search_time_in_ms = search_time_in_ms / query_size; 137 | 138 | const double ave_num_results = get_average(counts); 139 | tfm::reportfln("average number of results: %g", ave_num_results); 140 | tfm::reportfln("average search time in ms: %g", ave_search_time_in_ms); 141 | 142 | STATISTIC_APPEND("search", {{"num_codes", test_size}, // 143 | {"ave_search_time_in_ms", ave_search_time_in_ms}, // 144 | {"ave_num_results", ave_num_results}}); 145 | 146 | if 
(ave_search_time_in_ms >= ABORT_THRESHOLD_MS) { 147 | tfm::warnfln("Abort search becasue the time exceeds %g ms", ABORT_THRESHOLD_MS); 148 | break; 149 | } 150 | 151 | prev_size = test_size; 152 | } 153 | 154 | const auto result_path = tfm::format("%s/range_search_int_ls-%s-%s-%dN-%dB-%dR.json", // 155 | result_dir, normalize_filepath(base_path), normalize_filepath(query_path), // 156 | N, bits, radius); 157 | #endif 158 | 159 | #ifdef ALGO_DYFT 160 | std::unique_ptr> index; 161 | 162 | if (algorithm == "array") { 163 | if (blocks == 1) { 164 | index = std::make_unique>(base_codes.get(), radius, splitthr, in_weight); 165 | } else { 166 | index = std::make_unique>>( 167 | base_codes.get(), radius, blocks == 0 ? radius / 2 + 1 : blocks, splitthr, in_weight); 168 | } 169 | } else if (algorithm == "art") { 170 | if (blocks == 1) { 171 | index = std::make_unique>(base_codes.get(), radius, splitthr, in_weight); 172 | } else { 173 | index = std::make_unique>>( 174 | base_codes.get(), radius, blocks == 0 ? radius / 2 + 1 : blocks, splitthr, in_weight); 175 | } 176 | } else { // algorithm == "mart" 177 | if (blocks == 1) { 178 | index = std::make_unique>(base_codes.get(), radius, splitthr, in_weight); 179 | } else { 180 | index = std::make_unique>>( 181 | base_codes.get(), radius, blocks == 0 ? 
radius / 2 + 1 : blocks, splitthr, in_weight); 182 | } 183 | } 184 | 185 | double insert_time_in_ms = 0.0; 186 | 187 | for (uint32_t test_size : test_scales) { 188 | tfm::printfln("# %d codes...", test_size); 189 | 190 | START_TIMER(Insert); 191 | for (uint32_t bi = prev_size; bi < test_size; bi++) { 192 | ABORT_IF_NE(bi, index->append()); 193 | } 194 | STOP_TIMER_V(Insert); 195 | 196 | insert_time_in_ms += GET_TIMER_MILLISEC(Insert); 197 | const double ave_insert_time_in_ms = insert_time_in_ms / test_size; 198 | 199 | START_TIMER(Search); 200 | for (uint32_t qi = 0; qi < query_size; qi++) { 201 | const vint_type* qvcode = query_codes->access(qi); 202 | 203 | uint32_t count = 0; 204 | uint32_t verify_count = 0; 205 | 206 | index->range_search(qvcode, [&](uint32_t bi) { 207 | const vint_type* bvcode = base_codes->access(bi); 208 | const int hamdist = vcode_tools::get_hamdist(qvcode, bvcode, bits, radius); 209 | if (hamdist <= radius) { 210 | count += 1; 211 | } 212 | verify_count += 1; 213 | }); 214 | 215 | counts[qi] = count; 216 | verify_counts[qi] = verify_count; 217 | } 218 | STOP_TIMER_V(Search); 219 | 220 | std::string selected = index->selects_ls() ? 
"LS" : "Trie"; 221 | tfm::reportfln("selected search: %s", selected); 222 | 223 | tfm::reportfln("ls_cost: %g", index->get_ls_cost()); 224 | tfm::reportfln("trie_cost: %g", index->get_trie_cost()); 225 | 226 | const double ave_search_time_in_ms = GET_TIMER_MILLISEC(Search) / query_size; 227 | tfm::reportfln("average insert time in ms: %g", ave_insert_time_in_ms); 228 | tfm::reportfln("average search time in ms: %g", ave_search_time_in_ms); 229 | 230 | const double ave_num_results = get_average(counts); 231 | const double ave_num_cadidates = get_average(verify_counts); 232 | tfm::reportfln("average number of results: %g", ave_num_results); 233 | tfm::reportfln("average number of candidates: %g", ave_num_cadidates); 234 | 235 | tfm::reportfln("num leaves: %d", index->get_leaves()); 236 | tfm::reportfln("num splits: %d", index->get_split_count()); 237 | 238 | STATISTIC_APPEND("search", {{"num_codes", test_size}, // 239 | {"num_leaves", index->get_leaves()}, // 240 | {"num_splits", index->get_split_count()}, // 241 | {"selected_search", selected}, // 242 | {"ave_insert_time_in_ms", ave_insert_time_in_ms}, // 243 | {"ave_search_time_in_ms", ave_search_time_in_ms}, // 244 | {"ave_num_results", ave_num_results}, // 245 | {"ave_num_cadidates", ave_num_cadidates}}); 246 | 247 | if (ave_search_time_in_ms >= ABORT_THRESHOLD_MS) { 248 | tfm::warnfln("Abort search becasue the time exceeds %g ms", ABORT_THRESHOLD_MS); 249 | break; 250 | } 251 | 252 | prev_size = test_size; 253 | } 254 | 255 | const auto result_path = tfm::format("%s/range_search_int_%s-%s-%s-%dN-%dB-%dR-%dT-%gW-%dK.json", // 256 | result_dir, algorithm, // 257 | normalize_filepath(base_path), normalize_filepath(query_path), // 258 | N, bits, radius, splitthr, in_weight * 100, blocks); 259 | #endif // defined(ALGO_DYFT) 260 | 261 | #if defined(ALGO_HMS1DV) 262 | try { 263 | double insert_time_in_ms = 0.0; 264 | hms1dv_index index(base_codes.get(), radius); 265 | 266 | for (uint32_t test_size : test_scales) { 267 
| tfm::printfln("# %d codes...", test_size); 268 | 269 | START_TIMER(Insert); 270 | for (uint32_t bi = prev_size; bi < test_size; bi++) { 271 | ABORT_IF_NE(bi, index.append()); 272 | } 273 | STOP_TIMER_V(Insert); 274 | 275 | insert_time_in_ms += GET_TIMER_MILLISEC(Insert); 276 | const double ave_insert_time_in_ms = insert_time_in_ms / test_size; 277 | 278 | START_TIMER(Search); 279 | for (uint32_t qi = 0; qi < query_size; qi++) { 280 | const vint_type* qvcode = query_codes->access(qi); 281 | 282 | uint32_t count = 0; 283 | uint32_t verify_count = 0; 284 | 285 | index.range_search(qvcode, [&](uint32_t bi) { 286 | const vint_type* bvcode = base_codes->access(bi); 287 | const int hamdist = vcode_tools::get_hamdist(qvcode, bvcode, bits, radius); 288 | if (hamdist <= radius) { 289 | count += 1; 290 | } 291 | verify_count += 1; 292 | }); 293 | counts[qi] = count; 294 | verify_counts[qi] = verify_count; 295 | } 296 | STOP_TIMER_V(Search); 297 | 298 | const double ave_search_time_in_ms = GET_TIMER_MILLISEC(Search) / query_size; 299 | tfm::reportfln("average insert time in ms: %g", ave_insert_time_in_ms); 300 | tfm::reportfln("average search time in ms: %g", ave_search_time_in_ms); 301 | 302 | const double ave_num_results = get_average(counts); 303 | const double ave_num_cadidates = get_average(verify_counts); 304 | tfm::reportfln("average number of results: %g", ave_num_results); 305 | tfm::reportfln("average number of candidates: %g", ave_num_cadidates); 306 | 307 | STATISTIC_APPEND("search", {{"num_codes", test_size}, // 308 | {"ave_insert_time_in_ms", ave_insert_time_in_ms}, // 309 | {"ave_search_time_in_ms", ave_search_time_in_ms}, // 310 | {"ave_num_results", ave_num_results}, // 311 | {"ave_num_cadidates", ave_num_cadidates}}); 312 | 313 | if (ave_search_time_in_ms >= ABORT_THRESHOLD_MS) { 314 | tfm::warnfln("Abort search becasue the time exceeds %g ms", ABORT_THRESHOLD_MS); 315 | break; 316 | } 317 | 318 | prev_size = test_size; 319 | } 320 | } catch (const 
std::bad_alloc& e) { 321 | tfm::errorfln("Allocation failed: %s", e.what()); 322 | } 323 | 324 | const auto result_path = tfm::format("%s/range_search_int_hms1dv-%s-%s-%dN-%dB-%dR.json", // 325 | result_dir, normalize_filepath(base_path), normalize_filepath(query_path), // 326 | N, bits, radius); 327 | #endif // defined(ALGO_HMS1DV) 328 | 329 | #if defined(ALGO_GV) 330 | double insert_time_in_ms = 0.0; 331 | gv_index index(base_codes.get(), radius); 332 | 333 | for (uint32_t test_size : test_scales) { 334 | tfm::printfln("# %d codes...", test_size); 335 | 336 | START_TIMER(Insert); 337 | for (uint32_t bi = prev_size; bi < test_size; bi++) { 338 | ABORT_IF_NE(bi, index.append()); 339 | } 340 | STOP_TIMER_V(Insert); 341 | 342 | insert_time_in_ms += GET_TIMER_MILLISEC(Insert); 343 | const double ave_insert_time_in_ms = insert_time_in_ms / test_size; 344 | 345 | START_TIMER(Search); 346 | for (uint32_t qi = 0; qi < query_size; qi++) { 347 | const vint_type* qvcode = query_codes->access(qi); 348 | 349 | uint32_t count = 0; 350 | uint32_t verify_count = 0; 351 | 352 | index.range_search(qvcode, [&](uint32_t bi) { 353 | const vint_type* bvcode = base_codes->access(bi); 354 | const int hamdist = vcode_tools::get_hamdist(qvcode, bvcode, bits, radius); 355 | if (hamdist <= radius) { 356 | count += 1; 357 | } 358 | verify_count += 1; 359 | }); 360 | counts[qi] = count; 361 | verify_counts[qi] = verify_count; 362 | } 363 | STOP_TIMER_V(Search); 364 | 365 | const double ave_search_time_in_ms = GET_TIMER_MILLISEC(Search) / query_size; 366 | tfm::reportfln("average insert time in ms: %g", ave_insert_time_in_ms); 367 | tfm::reportfln("average search time in ms: %g", ave_search_time_in_ms); 368 | 369 | const double ave_num_results = get_average(counts); 370 | const double ave_num_cadidates = get_average(verify_counts); 371 | tfm::reportfln("average number of results: %g", ave_num_results); 372 | tfm::reportfln("average number of candidates: %g", ave_num_cadidates); 373 | 374 | 
STATISTIC_APPEND("search", {{"num_codes", test_size}, // 375 | {"ave_insert_time_in_ms", ave_insert_time_in_ms}, // 376 | {"ave_search_time_in_ms", ave_search_time_in_ms}, // 377 | {"ave_num_results", ave_num_results}, // 378 | {"ave_num_cadidates", ave_num_cadidates}}); 379 | 380 | if (ave_search_time_in_ms >= ABORT_THRESHOLD_MS) { 381 | tfm::warnfln("Abort search becasue the time exceeds %g ms", ABORT_THRESHOLD_MS); 382 | break; 383 | } 384 | 385 | prev_size = test_size; 386 | } 387 | 388 | const auto result_path = tfm::format("%s/range_search_int_gv-%s-%s-%dN-%dB-%dR.json", // 389 | result_dir, normalize_filepath(base_path), normalize_filepath(query_path), // 390 | N, bits, radius); 391 | #endif // defined(ALGO_GV) 392 | 393 | make_directory(result_dir); 394 | STATISTIC_SAVE(result_path); 395 | tfm::printfln("wrote %s", result_path); 396 | 397 | return 0; 398 | } 399 | // Entry point: dispatches to main_template<32> or main_template<64> according to the "ints" option (default 64); prints an error and returns 1 for any other value. 400 | int main(int argc, char** argv) { 401 | #ifndef NDEBUG 402 | tfm::warnfln("The code is running in debug mode."); 403 | #endif 404 | 405 | auto p = make_parser(argc, argv); 406 | const auto ints = p.get("ints", 64); 407 | 408 | switch (ints) { 409 | case 32: 410 | return main_template<32>(p); 411 | case 64: 412 | return main_template<64>(p); 413 | default: 414 | tfm::errorfln("unsupported ints: %d (expected 32 or 64)", ints); break; 415 | } 416 | 417 | return 1; 418 | } -------------------------------------------------------------------------------- /src/sample_bin.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | using namespace dyft; 8 | // Draws num codes from codes with replacement, picking each index as splitmix64 output modulo codes.size(); aborts on an empty codes (the modulo would divide by zero). 9 | std::vector sample_codes(const std::vector& codes, uint32_t num, uint64_t seed) { 10 | ABORT_IF(codes.empty()); splitmix64 engine(seed); 11 | std::vector sampled(num); 12 | for (uint32_t i = 0; i < num; ++i) { 13 | sampled[i] = codes[engine.next() % codes.size()]; 14 | } 15 | return sampled; 16 | } 17 | // Usage: input_path output_path num [-s seed]; writes num sampled codes to output_path as raw uint64_t values. 18 | int main(int argc, char** argv) { 19 | #ifndef NDEBUG 20 | tfm::warnfln("The code is running in debug mode."); 21 | #endif 22 | 23 | 
cmd_line_parser::parser p(argc, argv); 24 | p.add("input_path", "input file path of simhashed codes"); 25 | p.add("output_path", "output file path of sampled simhashed codes"); 26 | p.add("num", "number of sample codes"); 27 | p.add("seed", "seed", "-s", false); 28 | ABORT_IF(!p.parse()); 29 | 30 | auto input_path = p.get("input_path"); 31 | auto output_path = p.get("output_path"); 32 | auto num = p.get("num"); 33 | auto seed = p.get("seed", 114514); 34 | 35 | const auto codes = load_vcodes_from_bin<64>(input_path); 36 | std::vector qcodes = sample_codes(codes->get_vcodes(), num, seed); 37 | 38 | auto ofs = make_ofstream(output_path); 39 | ofs.write(reinterpret_cast(qcodes.data()), sizeof(uint64_t) * num); 40 | 41 | return 0; 42 | } -------------------------------------------------------------------------------- /src/sample_int.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace dyft; 10 | // NOTE(review): main below does not call sample_codes; it samples with std::default_random_engine instead. Guarded against an empty codes like the sample_bin.cpp version. 11 | std::vector sample_codes(const std::vector& codes, uint32_t num, uint64_t seed) { 12 | ABORT_IF(codes.empty()); splitmix64 engine(seed); 13 | std::vector sampled(num); 14 | for (uint32_t i = 0; i < num; ++i) { 15 | sampled[i] = codes[engine.next() % codes.size()]; 16 | } 17 | return sampled; 18 | } 19 | // Usage: input_path output_path num [-s seed]; reads [uint32_t dim][dim bytes] records (dim <= 256, the same dim for every record) and writes num of them, drawn with replacement, in the same format. 20 | int main(int argc, char** argv) { 21 | #ifndef NDEBUG 22 | tfm::warnfln("The code is running in debug mode."); 23 | #endif 24 | 25 | cmd_line_parser::parser p(argc, argv); 26 | p.add("input_path", "input file path of simhashed codes"); 27 | p.add("output_path", "output file path of sampled simhashed codes"); 28 | p.add("num", "number of sample codes"); 29 | p.add("seed", "seed", "-s", false); 30 | ABORT_IF(!p.parse()); 31 | 32 | auto input_path = p.get("input_path"); 33 | auto output_path = p.get("output_path"); 34 | auto num = p.get("num"); 35 | auto seed = p.get("seed", 114514); 36 | 37 | std::vector codes; 38 | uint32_t size = 0; 39 | uint32_t dim = 0; 40 | 
for (auto ifs = make_ifstream(input_path);;) { 42 | ifs.read(reinterpret_cast(&dim), sizeof(uint32_t)); 43 | if (ifs.eof()) { 44 | break; 45 | } 46 | ABORT_IF_LT(256, dim); 47 | ABORT_IF(codes.size() != uint64_t(size) * dim); // all records must share one dim: output offsets below are computed as index * dim 48 | uint8_t code[256]; 49 | ifs.read(reinterpret_cast(code), sizeof(uint8_t) * dim); ABORT_IF(!ifs); // reject a truncated last record 50 | std::copy(code, code + dim, std::back_inserter(codes)); 51 | 52 | size += 1; 53 | } 54 | ABORT_IF(size == 0); // an empty input would make size - 1 wrap around below 55 | std::default_random_engine engine(seed); 56 | std::uniform_int_distribution dist(0, size - 1); 57 | 58 | auto ofs = make_ofstream(output_path); 59 | for (size_t i = 0; i < num; ++i) { 60 | const uint64_t pos = uint64_t(dist(engine)) * dim; // widen before multiplying to avoid 32-bit overflow 61 | ofs.write(reinterpret_cast(&dim), sizeof(uint32_t)); 62 | ofs.write(reinterpret_cast(codes.data() + pos), sizeof(uint8_t) * dim); 63 | } 64 | 65 | return 0; 66 | } -------------------------------------------------------------------------------- /src/simhash.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | 15 | using namespace dyft; 16 | 17 | const size_t REPETITIONS = 64; // i.e., output code length 18 | // Reads .fvecs records ([uint32_t dim][dim floats]); next() returns the vector, or an empty vector at end of file. 19 | class fvecs_iterator { 20 | public: 21 | fvecs_iterator(const std::string& path, size_t dimension) : m_ifs(make_ifstream(path)), m_vec(dimension) {} 22 | 23 | const std::vector& next() { 24 | uint32_t dim = 0; 25 | m_ifs.read(reinterpret_cast(&dim), sizeof(uint32_t)); 26 | if (m_ifs.eof()) { 27 | m_vec.clear(); 28 | return m_vec; 29 | } 30 | ABORT_IF_NE(m_vec.size(), dim); 31 | m_ifs.read(reinterpret_cast(m_vec.data()), sizeof(float) * dim); ABORT_IF(!m_ifs); // reject a truncated record 32 | return m_vec; 33 | } 34 | 35 | private: 36 | std::ifstream m_ifs; 37 | std::vector m_vec; 38 | }; 39 | // Reads .bvecs records ([uint32_t dim][dim bytes]), converting each byte into m_vec; next() returns an empty vector at end of file. 40 | class bvecs_iterator { 41 | public: 42 | bvecs_iterator(const std::string& path, size_t dimension) : m_ifs(make_ifstream(path)), m_vec(dimension) {} 43 | 44 | const std::vector& next() { 45 | uint32_t dim = 0; 46 | 
m_ifs.read(reinterpret_cast(&dim), sizeof(uint32_t)); 47 | if (m_ifs.eof()) { 48 | m_vec.clear(); 49 | return m_vec; 50 | } 51 | ABORT_IF_NE(m_vec.size(), dim); 52 | for (uint32_t i = 0; i < dim; ++i) { 53 | uint8_t v; 54 | m_ifs.read(reinterpret_cast(&v), sizeof(uint8_t)); ABORT_IF(!m_ifs); // reject a truncated record before v is used 55 | m_vec[i] = static_cast(v); 56 | } 57 | return m_vec; 58 | } 59 | 60 | private: 61 | std::ifstream m_ifs; 62 | std::vector m_vec; 63 | }; 64 | // Reads libsvm lines ("label index:value ...", 1-based indices) into a zero-filled dense vector; next() returns an empty vector at end of file. 65 | class libsvm_iterator { 66 | public: 67 | libsvm_iterator(const std::string& path, size_t dimension) : m_ifs(make_ifstream(path)), m_vec(dimension) {} 68 | 69 | const std::vector& next() { 70 | if (!std::getline(m_ifs, m_row)) { 71 | m_vec.clear(); 72 | return m_vec; 73 | } 74 | 75 | std::fill(m_vec.begin(), m_vec.end(), 0.0); 76 | 77 | std::istringstream iss(m_row); 78 | iss.ignore(std::numeric_limits::max(), ' '); // ignore the label 79 | 80 | for (std::string taken; iss >> taken;) { 81 | const auto pos = taken.find(':'); 82 | ABORT_IF_EQ(pos, std::string::npos); 83 | 84 | const uint32_t id = static_cast(std::stoul(taken.substr(0, pos))) - 1; 85 | const float weight = std::stof(taken.substr(pos + 1)); 86 | 87 | ABORT_IF_LE(m_vec.size(), id); 88 | m_vec[id] = weight; 89 | } 90 | 91 | return m_vec; 92 | } 93 | 94 | private: 95 | std::string m_row; 96 | std::ifstream m_ifs; 97 | std::vector m_vec; 98 | }; 99 | // Reads lines of "index:value" tokens (0-based indices, no leading label) into a zero-filled dense vector; next() returns an empty vector at end of file. 100 | class fasttext_iterator { 101 | public: 102 | fasttext_iterator(const std::string& path, size_t dimension) : m_ifs(make_ifstream(path)), m_vec(dimension) {} 103 | 104 | const std::vector& next() { 105 | if (!std::getline(m_ifs, m_row)) { 106 | m_vec.clear(); 107 | return m_vec; 108 | } 109 | 110 | std::fill(m_vec.begin(), m_vec.end(), 0.0); 111 | 112 | std::istringstream iss(m_row); 113 | for (std::string taken; iss >> taken;) { 114 | const auto pos = taken.find(':'); 115 | ABORT_IF_EQ(pos, std::string::npos); 116 | 117 | const uint32_t id = static_cast(std::stoul(taken.substr(0, pos))); 118 | const float weight = 
std::stof(taken.substr(pos + 1)); 119 | 120 | ABORT_IF_LE(m_vec.size(), id); 121 | m_vec[id] = weight; 122 | } 123 | 124 | return m_vec; 125 | } 126 | 127 | private: 128 | std::string m_row; 129 | std::ifstream m_ifs; 130 | std::vector m_vec; 131 | }; 132 | // Returns dimension * repetitions N(0, 1) samples drawn from std::default_random_engine(seed); main_template uses entries [i * dimension, (i + 1) * dimension) as the i-th projection vector. 133 | std::vector gen_random_matrix(size_t dimension, size_t repetitions, size_t seed) { 134 | std::vector random_matrix(dimension * repetitions); 135 | std::default_random_engine engine(seed); 136 | std::normal_distribution dist(0.0, 1.0); 137 | for (size_t i = 0; i < random_matrix.size(); ++i) { 138 | random_matrix[i] = dist(engine); 139 | } 140 | return random_matrix; 141 | } 142 | // Returns the text after the last '.' in path, or the whole path when it contains no '.'. 143 | std::string get_ext(std::string path) { 144 | size_t pos = path.find_last_of(".") + 1; 145 | return path.substr(pos, path.size() - pos); 146 | } 147 | // One simhash bit: true iff the dot product of data_vec and random_vec (both of length dimension) is non-negative. 148 | bool simhash(const float* data_vec, const float* random_vec, size_t dimension) { 149 | return std::inner_product(data_vec, data_vec + dimension, random_vec, 0.0) >= 0; 150 | } 151 | // Hashes every vector produced by Iterator and writes one uint64_t code per vector to "<output_path>.simhash<REPETITIONS>"; bit i of a code is simhash() against projection vector i. 152 | template 153 | int main_template(const cmd_line_parser::parser& p) { 154 | auto input_path = p.get("input_path"); 155 | auto output_path = p.get("output_path"); 156 | auto dimension = p.get("dimension"); 157 | auto seed = p.get("seed", 114514); 158 | 159 | Iterator data_itr(input_path, dimension); 160 | auto ofs = make_ofstream(tfm::format("%s.simhash%d", output_path, REPETITIONS)); 161 | 162 | const std::vector random_matrix = gen_random_matrix(dimension, REPETITIONS, seed); 163 | 164 | size_t num_codes = 0; 165 | for (;; ++num_codes) { 166 | if (num_codes % 100000 == 0) { 167 | tfm::printfln("%d vecs processed...", num_codes); 168 | } 169 | 170 | const std::vector& data_vec = data_itr.next(); 171 | if (data_vec.empty()) { 172 | break; 173 | } 174 | ABORT_IF_NE(data_vec.size(), dimension); 175 | 176 | uint64_t code = 0; 177 | for (size_t i = 0; i < REPETITIONS; ++i) { 178 | const float* random_vec = random_matrix.data() + (i * dimension); 179 | if (simhash(data_vec.data(), random_vec, 
dimension)) { 180 | code |= (1ULL << i); 181 | } 182 | } 183 | ofs.write(reinterpret_cast(&code), sizeof(uint64_t)); 184 | } 185 | 186 | tfm::printfln("Completed! The total number of vectors: %d", num_codes); 187 | return 0; 188 | } 189 | // Entry point: dispatches to main_template with the iterator matching the "format" option; prints an error and returns 1 for an unknown format. 190 | int main(int argc, char** argv) { 191 | #ifndef NDEBUG 192 | tfm::warnfln("The code is running in debug mode."); 193 | #endif 194 | 195 | cmd_line_parser::parser p(argc, argv); 196 | p.add("input_path", "input file path of vectors"); 197 | p.add("output_path", "output file path of hashed codes"); 198 | p.add("format", "format (fvecs|bvecs|libsvm|fasttext)"); 199 | p.add("dimension", "dimension of data vectors"); 200 | p.add("seed", "seed", "-s", false); 201 | ABORT_IF(!p.parse()); 202 | 203 | const auto format = p.get("format"); 204 | 205 | if (format == "fvecs") { 206 | return main_template(p); 207 | } else if (format == "bvecs") { 208 | return main_template(p); 209 | } else if (format == "libsvm") { 210 | return main_template(p); 211 | } else if (format == "fasttext") { 212 | return main_template(p); 213 | } 214 | 215 | tfm::errorfln("invalid format"); 216 | return 1; 217 | } --------------------------------------------------------------------------------