├── .clang-format ├── .github └── workflows │ └── main.yml ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── bench.cc ├── cmake_modules └── googletest.cmake ├── example.cc ├── include └── rs │ ├── builder.h │ ├── common.h │ ├── multi_map.h │ ├── radix_spline.h │ └── serializer.h └── test ├── multi_map_test.cc └── radix_spline_test.cc /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: Google 4 | AccessModifierOffset: -1 5 | AlignAfterOpenBracket: Align 6 | AlignConsecutiveMacros: false 7 | AlignConsecutiveAssignments: false 8 | AlignConsecutiveDeclarations: false 9 | AlignEscapedNewlines: Left 10 | AlignOperands: true 11 | AlignTrailingComments: true 12 | AllowAllArgumentsOnNextLine: true 13 | AllowAllConstructorInitializersOnNextLine: true 14 | AllowAllParametersOfDeclarationOnNextLine: true 15 | AllowShortBlocksOnASingleLine: Never 16 | AllowShortCaseLabelsOnASingleLine: false 17 | AllowShortFunctionsOnASingleLine: All 18 | AllowShortLambdasOnASingleLine: All 19 | AllowShortIfStatementsOnASingleLine: WithoutElse 20 | AllowShortLoopsOnASingleLine: true 21 | AlwaysBreakAfterDefinitionReturnType: None 22 | AlwaysBreakAfterReturnType: None 23 | AlwaysBreakBeforeMultilineStrings: true 24 | AlwaysBreakTemplateDeclarations: Yes 25 | BinPackArguments: true 26 | BinPackParameters: true 27 | BraceWrapping: 28 | AfterCaseLabel: false 29 | AfterClass: false 30 | AfterControlStatement: false 31 | AfterEnum: false 32 | AfterFunction: false 33 | AfterNamespace: false 34 | AfterObjCDeclaration: false 35 | AfterStruct: false 36 | AfterUnion: false 37 | AfterExternBlock: false 38 | BeforeCatch: false 39 | BeforeElse: false 40 | IndentBraces: false 41 | SplitEmptyFunction: true 42 | SplitEmptyRecord: true 43 | SplitEmptyNamespace: true 44 | BreakBeforeBinaryOperators: None 45 | BreakBeforeBraces: Attach 46 | BreakBeforeInheritanceComma: false 47 | BreakInheritanceList: 
BeforeColon 48 | BreakBeforeTernaryOperators: true 49 | BreakConstructorInitializersBeforeComma: false 50 | BreakConstructorInitializers: BeforeColon 51 | BreakAfterJavaFieldAnnotations: false 52 | BreakStringLiterals: true 53 | ColumnLimit: 80 54 | CommentPragmas: '^ IWYU pragma:' 55 | CompactNamespaces: false 56 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 57 | ConstructorInitializerIndentWidth: 4 58 | ContinuationIndentWidth: 4 59 | Cpp11BracedListStyle: true 60 | DeriveLineEnding: true 61 | DerivePointerAlignment: false 62 | DisableFormat: false 63 | ExperimentalAutoDetectBinPacking: false 64 | FixNamespaceComments: true 65 | ForEachMacros: 66 | - foreach 67 | - Q_FOREACH 68 | - BOOST_FOREACH 69 | IncludeBlocks: Regroup 70 | IncludeCategories: 71 | - Regex: '^' 72 | Priority: 2 73 | SortPriority: 0 74 | - Regex: '^<.*\.h>' 75 | Priority: 1 76 | SortPriority: 0 77 | - Regex: '^<.*' 78 | Priority: 2 79 | SortPriority: 0 80 | - Regex: '.*' 81 | Priority: 3 82 | SortPriority: 0 83 | IncludeIsMainRegex: '([-_](test|unittest))?$' 84 | IncludeIsMainSourceRegex: '' 85 | IndentCaseLabels: true 86 | IndentGotoLabels: true 87 | IndentPPDirectives: None 88 | IndentWidth: 2 89 | IndentWrappedFunctionNames: false 90 | JavaScriptQuotes: Leave 91 | JavaScriptWrapImports: true 92 | KeepEmptyLinesAtTheStartOfBlocks: false 93 | MacroBlockBegin: '' 94 | MacroBlockEnd: '' 95 | MaxEmptyLinesToKeep: 1 96 | NamespaceIndentation: None 97 | ObjCBinPackProtocolList: Never 98 | ObjCBlockIndentWidth: 2 99 | ObjCSpaceAfterProperty: false 100 | ObjCSpaceBeforeProtocolList: true 101 | PenaltyBreakAssignment: 2 102 | PenaltyBreakBeforeFirstCallParameter: 1 103 | PenaltyBreakComment: 300 104 | PenaltyBreakFirstLessLess: 120 105 | PenaltyBreakString: 1000 106 | PenaltyBreakTemplateDeclaration: 10 107 | PenaltyExcessCharacter: 1000000 108 | PenaltyReturnTypeOnItsOwnLine: 200 109 | PointerAlignment: Left 110 | RawStringFormats: 111 | - Language: Cpp 112 | Delimiters: 113 | - cc 114 | - 
CC 115 | - cpp 116 | - Cpp 117 | - CPP 118 | - 'c++' 119 | - 'C++' 120 | CanonicalDelimiter: '' 121 | BasedOnStyle: google 122 | - Language: TextProto 123 | Delimiters: 124 | - pb 125 | - PB 126 | - proto 127 | - PROTO 128 | EnclosingFunctions: 129 | - EqualsProto 130 | - EquivToProto 131 | - PARSE_PARTIAL_TEXT_PROTO 132 | - PARSE_TEST_PROTO 133 | - PARSE_TEXT_PROTO 134 | - ParseTextOrDie 135 | - ParseTextProtoOrDie 136 | CanonicalDelimiter: '' 137 | BasedOnStyle: google 138 | ReflowComments: true 139 | SortIncludes: true 140 | SortUsingDeclarations: true 141 | SpaceAfterCStyleCast: false 142 | SpaceAfterLogicalNot: false 143 | SpaceAfterTemplateKeyword: true 144 | SpaceBeforeAssignmentOperators: true 145 | SpaceBeforeCpp11BracedList: false 146 | SpaceBeforeCtorInitializerColon: true 147 | SpaceBeforeInheritanceColon: true 148 | SpaceBeforeParens: ControlStatements 149 | SpaceBeforeRangeBasedForLoopColon: true 150 | SpaceInEmptyBlock: false 151 | SpaceInEmptyParentheses: false 152 | SpacesBeforeTrailingComments: 2 153 | SpacesInAngles: false 154 | SpacesInConditionalStatement: false 155 | SpacesInContainerLiterals: true 156 | SpacesInCStyleCastParentheses: false 157 | SpacesInParentheses: false 158 | SpacesInSquareBrackets: false 159 | SpaceBeforeSquareBrackets: false 160 | Standard: Auto 161 | StatementMacros: 162 | - Q_UNUSED 163 | - QT_REQUIRE_VERSION 164 | TabWidth: 8 165 | UseCRLF: false 166 | UseTab: Never 167 | ... 168 | 169 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | # Controls when the action will run. 
Triggers the workflow on every push 4 | # to any branch (no pull-request trigger or branch filter is configured) 5 | on: push 6 | 7 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 8 | jobs: 9 | run_tests: 10 | runs-on: self-hosted # Runs on our locally hosted machine 11 | 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Build 15 | run: | 16 | mkdir -p build 17 | cd build 18 | cmake .. 19 | make 20 | - name: Tester 21 | run: | 22 | ./build/tester 23 | - name: Radix Spline 24 | run: | 25 | ./build/tester 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | cmake-build-debug/ 3 | cmake-build-release/ 4 | .idea/ 5 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | project(radixspline) 3 | 4 | set(CMAKE_CXX_STANDARD 14) 5 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native") 6 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g3 -Wall -Wextra") 7 | 8 | find_package(Threads REQUIRED) 9 | set(THREADS_PREFER_PTHREAD_FLAG ON) 10 | 11 | include("${CMAKE_SOURCE_DIR}/cmake_modules/googletest.cmake") 12 | 13 | include_directories( 14 | ${GTEST_INCLUDE_DIR} 15 | ${CMAKE_SOURCE_DIR} 16 | ) 17 | 18 | file(GLOB INCLUDE_H "include/rs/*.h") 19 | set(EXAMPLE_FILES example.cc) 20 | set(BENCH_FILES bench.cc) 21 | file(GLOB TEST_CC "test/*_test.cc") 22 | 23 | add_executable(example ${INCLUDE_H} ${EXAMPLE_FILES}) 24 | add_executable(bench ${INCLUDE_H} ${BENCH_FILES}) 25 | 26 | add_executable(tester ${TEST_CC}) 27 | target_link_libraries(tester gtest gtest_main Threads::Threads) 28 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2020 Andreas Kipf, Alexander van Renen, Mihail Stoian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | RadixSpline: A Single-Pass Learned Index 2 | ==== 3 | 4 | ![](https://github.com/learnedsystems/RadixSpline/workflows/CI/badge.svg) 5 | 6 | A read-only learned index structure that can be built in a single pass over sorted data. Can be used as a drop-in replacement for ``std::multimap``. Currently limited to `uint32_t` and `uint64_t` data types. 7 | 8 | ## Build 9 | 10 | ``` 11 | mkdir -p build 12 | cd build 13 | cmake -DCMAKE_BUILD_TYPE=Release .. 
14 | make 15 | ./example 16 | ./tester 17 | ``` 18 | 19 | ## Examples 20 | 21 | Using ``rs::Builder`` to index sorted data in one pass, without copying the data: 22 | 23 | ```c++ 24 | // Create random keys. 25 | vector<uint64_t> keys(1e6); 26 | generate(keys.begin(), keys.end(), rand); 27 | keys.push_back(8128); 28 | sort(keys.begin(), keys.end()); 29 | 30 | // Build RadixSpline. 31 | uint64_t min = keys.front(); 32 | uint64_t max = keys.back(); 33 | rs::Builder<uint64_t> rsb(min, max); 34 | for (const auto& key : keys) rsb.AddKey(key); 35 | rs::RadixSpline<uint64_t> rs = rsb.Finalize(); 36 | 37 | // Search using RadixSpline. 38 | rs::SearchBound bound = rs.GetSearchBound(8128); 39 | cout << "The search key is in the range: [" 40 | << bound.begin << ", " << bound.end << ")" << endl; 41 | auto start = begin(keys) + bound.begin, last = begin(keys) + bound.end; 42 | cout << "The key is at position: " << std::lower_bound(start, last, 8128) - begin(keys) << endl; 43 | ``` 44 | 45 | Using ``rs::MultiMap`` to index unsorted data, which internally creates a sorted copy: 46 | 47 | ```c++ 48 | vector<pair<uint64_t, char>> data = {{1ull, 'a'}, 49 | {12ull, 'b'}, 50 | {7ull, 'c'}, // Unsorted. 
51 | {42ull, 'd'}}; 52 | rs::MultiMap map(begin(data), end(data)); 53 | 54 | cout << "find(7): '" << map.find(7)->second << "'" << endl; 55 | cout << "lower_bound(3): '" << map.lower_bound(3)->second << "'" << endl; 56 | ``` 57 | 58 | ## Cite 59 | 60 | Please cite our [aiDM@SIGMOD 2020 paper](https://dl.acm.org/doi/10.1145/3401071.3401659) if you use this code in your own work: 61 | 62 | ``` 63 | @inproceedings{radixspline, 64 | author = {Andreas Kipf and 65 | Ryan Marcus and 66 | Alexander van Renen and 67 | Mihail Stoian and 68 | Alfons Kemper and 69 | Tim Kraska and 70 | Thomas Neumann}, 71 | title = {{RadixSpline}: a single-pass learned index}, 72 | booktitle = {Proceedings of the Third International Workshop on Exploiting Artificial 73 | Intelligence Techniques for Data Management, aiDM@SIGMOD 2020, Portland, 74 | Oregon, USA, June 19, 2020}, 75 | pages = {5:1--5:5}, 76 | year = {2020}, 77 | url = {https://doi.org/10.1145/3401071.3401659}, 78 | doi = {10.1145/3401071.3401659}, 79 | timestamp = {Mon, 08 Jun 2020 19:13:59 +0200}, 80 | biburl = {https://dblp.org/rec/conf/sigmod/KipfMRSKK020.bib}, 81 | bibsource = {dblp computer science bibliography, https://dblp.org} 82 | } 83 | ``` 84 | -------------------------------------------------------------------------------- /bench.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "include/rs/multi_map.h" 7 | 8 | using namespace std; 9 | 10 | namespace rs_manual_tuning { 11 | 12 | // Returns 13 | pair GetTuning(const string& data_filename, 14 | uint32_t size_scale) { 15 | assert(size_scale >= 1 && size_scale <= 10); 16 | 17 | string dataset = data_filename; 18 | 19 | // Cut the prefix of the filename 20 | size_t pos = dataset.find_last_of('/'); 21 | if (pos != string::npos) { 22 | dataset.erase(dataset.begin(), dataset.begin() + pos + 1); 23 | } 24 | 25 | using Configs = const vector>; 26 | 27 | if (dataset == 
"normal_200M_uint32") { 28 | Configs configs = {{10, 6}, {15, 1}, {16, 1}, {18, 1}, {20, 1}, 29 | {21, 1}, {24, 1}, {25, 1}, {26, 1}, {26, 1}}; 30 | return configs[10 - size_scale]; 31 | } 32 | 33 | if (dataset == "normal_200M_uint64") { 34 | Configs configs = {{14, 2}, {16, 1}, {16, 1}, {20, 1}, {22, 1}, 35 | {24, 1}, {26, 1}, {26, 1}, {28, 1}, {28, 1}}; 36 | return configs[10 - size_scale]; 37 | } 38 | 39 | if (dataset == "lognormal_200M_uint32") { 40 | Configs configs = {{12, 20}, {16, 3}, {16, 2}, {18, 1}, {20, 1}, 41 | {22, 1}, {24, 1}, {24, 1}, {26, 1}, {28, 1}}; 42 | return configs[10 - size_scale]; 43 | } 44 | 45 | if (dataset == "lognormal_200M_uint64") { 46 | Configs configs = {{12, 3}, {18, 1}, {18, 1}, {20, 1}, {22, 1}, 47 | {24, 1}, {26, 1}, {26, 1}, {28, 1}, {28, 1}}; 48 | return configs[10 - size_scale]; 49 | } 50 | 51 | if (dataset == "uniform_dense_200M_uint32") { 52 | Configs configs = {{4, 2}, {16, 2}, {18, 1}, {20, 1}, {20, 1}, 53 | {22, 2}, {24, 1}, {26, 3}, {26, 3}, {28, 2}}; 54 | return configs[10 - size_scale]; 55 | } 56 | 57 | if (dataset == "uniform_dense_200M_uint64") { 58 | Configs configs = {{4, 2}, {16, 1}, {16, 1}, {20, 1}, {22, 1}, 59 | {24, 1}, {24, 1}, {26, 1}, {28, 1}, {28, 1}}; 60 | return configs[10 - size_scale]; 61 | } 62 | 63 | if (dataset == "uniform_dense_200M_uint64") { 64 | Configs configs = {{4, 2}, {16, 1}, {16, 1}, {20, 1}, {22, 1}, 65 | {24, 1}, {24, 1}, {26, 1}, {28, 1}, {28, 1}}; 66 | return configs[10 - size_scale]; 67 | } 68 | 69 | if (dataset == "uniform_sparse_200M_uint32") { 70 | Configs configs = {{12, 220}, {14, 100}, {14, 80}, {16, 30}, {18, 20}, 71 | {20, 10}, {20, 8}, {20, 5}, {24, 3}, {26, 1}}; 72 | return configs[10 - size_scale]; 73 | } 74 | 75 | if (dataset == "uniform_sparse_200M_uint64") { 76 | Configs configs = {{12, 150}, {14, 70}, {16, 50}, {18, 20}, {20, 20}, 77 | {20, 9}, {20, 5}, {24, 3}, {26, 2}, {28, 1}}; 78 | return configs[10 - size_scale]; 79 | } 80 | 81 | // Books (or amazon in the paper) 
82 | if (dataset == "books_200M_uint32") { 83 | Configs configs = {{14, 250}, {14, 250}, {16, 190}, {18, 80}, {18, 50}, 84 | {22, 20}, {22, 9}, {22, 8}, {24, 3}, {28, 2}}; 85 | return configs[10 - size_scale]; 86 | } 87 | 88 | if (dataset == "books_200M_uint64") { 89 | Configs configs = {{12, 380}, {16, 170}, {16, 110}, {20, 50}, {20, 30}, 90 | {22, 20}, {22, 10}, {24, 3}, {26, 3}, {28, 2}}; 91 | return configs[10 - size_scale]; 92 | } 93 | 94 | if (dataset == "books_400M_uint64") { 95 | Configs configs = {{16, 220}, {16, 220}, {18, 160}, {20, 60}, {20, 40}, 96 | {22, 20}, {22, 7}, {26, 3}, {28, 2}, {28, 1}}; 97 | return configs[10 - size_scale]; 98 | } 99 | 100 | if (dataset == "books_600M_uint64") { 101 | Configs configs = {{18, 330}, {18, 330}, {18, 190}, {20, 70}, {22, 50}, 102 | {22, 20}, {24, 7}, {26, 3}, {28, 2}, {28, 1}}; 103 | return configs[10 - size_scale]; 104 | } 105 | 106 | if (dataset == "books_800M_uint64") { 107 | Configs configs = {{18, 320}, {18, 320}, {18, 200}, {22, 80}, {22, 60}, 108 | {22, 20}, {24, 9}, {26, 3}, {28, 3}, {28, 3}}; 109 | return configs[10 - size_scale]; 110 | } 111 | 112 | // Facebook 113 | if (dataset == "fb_200M_uint64") { 114 | Configs configs = {{8, 140}, {8, 140}, {8, 140}, {8, 140}, {10, 90}, 115 | {22, 90}, {24, 70}, {26, 80}, {26, 7}, {28, 80}}; 116 | return configs[10 - size_scale]; 117 | } 118 | 119 | // OSM 120 | if (dataset == "osm_cellids_200M_uint64") { 121 | Configs configs = {{20, 160}, {20, 160}, {20, 160}, {20, 160}, {20, 80}, 122 | {24, 40}, {24, 20}, {26, 8}, {26, 3}, {28, 2}}; 123 | return configs[10 - size_scale]; 124 | } 125 | 126 | if (dataset == "osm_cellids_400M_uint64") { 127 | Configs configs = {{20, 190}, {20, 190}, {20, 190}, {20, 190}, {22, 80}, 128 | {24, 20}, {26, 20}, {26, 10}, {28, 6}, {28, 2}}; 129 | return configs[10 - size_scale]; 130 | } 131 | 132 | if (dataset == "osm_cellids_600M_uint64") { 133 | Configs configs = {{20, 190}, {20, 190}, {20, 190}, {22, 180}, {22, 100}, 134 | {24, 20}, 
{26, 20}, {28, 7}, {28, 5}, {28, 2}}; 135 | return configs[10 - size_scale]; 136 | } 137 | 138 | if (dataset == "osm_cellids_800M_uint64") { 139 | Configs configs = {{22, 190}, {22, 190}, {22, 190}, {22, 190}, {24, 190}, 140 | {26, 30}, {26, 20}, {28, 7}, {28, 5}, {28, 1}}; 141 | return configs[10 - size_scale]; 142 | } 143 | 144 | // Wiki 145 | if (dataset == "wiki_ts_200M_uint64") { 146 | Configs configs = {{14, 100}, {14, 100}, {16, 60}, {18, 20}, {20, 20}, 147 | {20, 9}, {20, 5}, {22, 3}, {26, 2}, {26, 1}}; 148 | return configs[10 - size_scale]; 149 | } 150 | 151 | cerr << "No tuning config for this dataset" << endl; 152 | throw; 153 | } 154 | } // namespace rs_manual_tuning 155 | 156 | namespace util { 157 | 158 | // Loads values from binary file into vector. 159 | template 160 | static vector load_data(const string& filename, bool print = true) { 161 | vector data; 162 | ifstream in(filename, ios::binary); 163 | if (!in.is_open()) { 164 | cerr << "unable to open " << filename << endl; 165 | exit(EXIT_FAILURE); 166 | } 167 | // Read size. 168 | uint64_t size; 169 | in.read(reinterpret_cast(&size), sizeof(uint64_t)); 170 | data.resize(size); 171 | // Read values. 172 | in.read(reinterpret_cast(data.data()), size * sizeof(T)); 173 | 174 | return data; 175 | } 176 | 177 | // Generates deterministic values for keys. 
178 | template 179 | static vector> add_values(const vector& keys) { 180 | vector> result; 181 | result.reserve(keys.size()); 182 | 183 | for (uint64_t i = 0; i < keys.size(); ++i) { 184 | pair row; 185 | row.first = keys[i]; 186 | row.second = i; 187 | 188 | result.push_back(row); 189 | } 190 | return result; 191 | } 192 | 193 | } // namespace util 194 | 195 | namespace { 196 | 197 | template 198 | class NonOwningMultiMap { 199 | public: 200 | using element_type = pair; 201 | 202 | NonOwningMultiMap(const vector& elements, 203 | size_t num_radix_bits = 18, size_t max_error = 32) 204 | : data_(elements) { 205 | assert(elements.size() > 0); 206 | 207 | // Create spline builder. 208 | const auto min_key = data_.front().first; 209 | const auto max_key = data_.back().first; 210 | rs::Builder rsb(min_key, max_key, num_radix_bits, max_error); 211 | 212 | // Build the radix spline. 213 | for (const auto& iter : data_) { 214 | rsb.AddKey(iter.first); 215 | } 216 | rs_ = rsb.Finalize(); 217 | } 218 | 219 | typename vector::const_iterator lower_bound(KeyType key) const { 220 | rs::SearchBound bound = rs_.GetSearchBound(key); 221 | return ::lower_bound(data_.begin() + bound.begin, data_.begin() + bound.end, 222 | key, [](const element_type& lhs, const KeyType& rhs) { 223 | return lhs.first < rhs; 224 | }); 225 | } 226 | 227 | uint64_t sum_up(KeyType key) const { 228 | uint64_t result = 0; 229 | auto iter = lower_bound(key); 230 | while (iter != data_.end() && iter->first == key) { 231 | result += iter->second; 232 | ++iter; 233 | } 234 | return result; 235 | } 236 | 237 | size_t GetSizeInByte() const { return rs_.GetSize(); } 238 | 239 | private: 240 | const vector& data_; 241 | rs::RadixSpline rs_; 242 | }; 243 | 244 | template 245 | struct Lookup { 246 | KeyType key; 247 | uint64_t value; 248 | }; 249 | 250 | template 251 | void Run(const string& data_file, const string lookup_file) { 252 | // Load data 253 | vector keys = util::load_data(data_file); 254 | vector> elements 
= util::add_values(keys); 255 | vector> lookups = 256 | util::load_data>(lookup_file); 257 | 258 | for (uint32_t size_config = 1; size_config <= 10; ++size_config) { 259 | // Get the config for tuning 260 | auto tuning = rs_manual_tuning::GetTuning(data_file, size_config); 261 | 262 | // Build RS 263 | auto build_begin = chrono::high_resolution_clock::now(); 264 | NonOwningMultiMap map(elements, tuning.first, 265 | tuning.second); 266 | auto build_end = chrono::high_resolution_clock::now(); 267 | uint64_t build_ns = 268 | chrono::duration_cast(build_end - build_begin) 269 | .count(); 270 | 271 | // Run queries 272 | auto lookup_begin = chrono::high_resolution_clock::now(); 273 | for (const Lookup& lookup_iter : lookups) { 274 | uint64_t sum = map.sum_up(lookup_iter.key); 275 | if (sum != lookup_iter.value) { 276 | cerr << "wrong result!" << endl; 277 | throw "error"; 278 | } 279 | } 280 | auto lookup_end = chrono::high_resolution_clock::now(); 281 | uint64_t lookup_ns = 282 | chrono::duration_cast(lookup_end - lookup_begin) 283 | .count(); 284 | 285 | cout << "RESULT:" 286 | << " data_file: " << data_file << " lookup_file: " << lookup_file 287 | << " radix_bit_count: " << tuning.first 288 | << " spline_error: " << tuning.second 289 | << " size_config: " << size_config 290 | << " used_memory[MB]: " << (map.GetSizeInByte() / 1000) / 1000.0 291 | << " build_time[s]: " << (build_ns / 1000 / 1000) / 1000.0 292 | << " ns/lookup: " << lookup_ns / lookups.size() << endl; 293 | } 294 | } 295 | 296 | } // namespace 297 | 298 | int main(int argc, char** argv) { 299 | if (argc != 3) { 300 | cerr << "usage: " << argv[0] << " " << endl; 301 | throw; 302 | } 303 | const string data_file = argv[1]; 304 | const string lookup_file = argv[2]; 305 | 306 | if (data_file.find("32") != string::npos) { 307 | Run(data_file, lookup_file); 308 | } else { 309 | Run(data_file, lookup_file); 310 | } 311 | 312 | return 0; 313 | } 314 | 
-------------------------------------------------------------------------------- /cmake_modules/googletest.cmake: -------------------------------------------------------------------------------- 1 | include(ExternalProject) 2 | find_package(Git REQUIRED) 3 | find_package(Threads REQUIRED) 4 | 5 | # Get googletest 6 | ExternalProject_Add( 7 | gtest_src 8 | PREFIX "extern/gtest" 9 | GIT_REPOSITORY "https://github.com/google/googletest.git" 10 | GIT_TAG "release-1.10.0" 11 | TIMEOUT 10 12 | CMAKE_ARGS 13 | -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/extern/gtest 14 | -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} 15 | -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} 16 | -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} 17 | UPDATE_COMMAND "" 18 | ) 19 | 20 | # Prepare gtest 21 | ExternalProject_Get_Property(gtest_src install_dir) 22 | set(GTEST_INCLUDE_DIR ${install_dir}/include) 23 | set(GTEST_LIBRARY_PATH ${install_dir}/lib/libgtest.a) 24 | set(GTEST_MAIN_LIBRARY_PATH ${install_dir}/lib/libgtest_main.a) 25 | add_library(gtest UNKNOWN IMPORTED) 26 | add_library(gtest_main UNKNOWN IMPORTED) 27 | set_property(TARGET gtest PROPERTY IMPORTED_LOCATION ${GTEST_LIBRARY_PATH}) 28 | set_property(TARGET gtest_main PROPERTY IMPORTED_LOCATION ${GTEST_MAIN_LIBRARY_PATH}) 29 | 30 | # Dependencies 31 | add_dependencies(gtest gtest_src) 32 | add_dependencies(gtest_main gtest_src) -------------------------------------------------------------------------------- /example.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "include/rs/multi_map.h" 4 | 5 | using namespace std; 6 | 7 | void RadixSplineExample() { 8 | // Create random keys. 9 | vector keys(1e6); 10 | generate(keys.begin(), keys.end(), rand); 11 | keys.push_back(8128); 12 | sort(keys.begin(), keys.end()); 13 | 14 | // Build RadixSpline. 
15 | uint64_t min = keys.front(); 16 | uint64_t max = keys.back(); 17 | rs::Builder rsb(min, max); 18 | for (const auto& key : keys) rsb.AddKey(key); 19 | rs::RadixSpline rs = rsb.Finalize(); 20 | 21 | // Search using RadixSpline. 22 | rs::SearchBound bound = rs.GetSearchBound(8128); 23 | cout << "The search key is in the range: [" << bound.begin << ", " 24 | << bound.end << ")" << endl; 25 | auto start = begin(keys) + bound.begin, last = begin(keys) + bound.end; 26 | cout << "The key is at position: " 27 | << std::lower_bound(start, last, 8128) - begin(keys) << endl; 28 | } 29 | 30 | void MultiMapExample() { 31 | vector> data = {{1ull, 'a'}, 32 | {12ull, 'b'}, 33 | {7ull, 'c'}, // Unsorted. 34 | {42ull, 'd'}}; 35 | rs::MultiMap map(begin(data), end(data)); 36 | 37 | cout << "find(7): '" << map.find(7)->second << "'" << endl; 38 | cout << "lower_bound(3): '" << map.lower_bound(3)->second << "'" << endl; 39 | } 40 | 41 | int main(int argc, char** argv) { 42 | RadixSplineExample(); 43 | MultiMapExample(); 44 | 45 | return 0; 46 | } 47 | -------------------------------------------------------------------------------- /include/rs/builder.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "common.h" 8 | #include "radix_spline.h" 9 | 10 | namespace rs { 11 | 12 | // Allows building a `RadixSpline` in a single pass over sorted data. 
13 | template 14 | class Builder { 15 | public: 16 | Builder(KeyType min_key, KeyType max_key, size_t num_radix_bits = 18, 17 | size_t max_error = 32) 18 | : min_key_(min_key), 19 | max_key_(max_key), 20 | num_radix_bits_(num_radix_bits), 21 | num_shift_bits_(GetNumShiftBits(max_key - min_key, num_radix_bits)), 22 | max_error_(max_error), 23 | curr_num_keys_(0), 24 | curr_num_distinct_keys_(0), 25 | prev_key_(min_key), 26 | prev_position_(0), 27 | prev_prefix_(0) { 28 | // Initialize radix table, needs to contain all prefixes up to the largest 29 | // key + 1. 30 | const uint32_t max_prefix = (max_key - min_key) >> num_shift_bits_; 31 | radix_table_.resize(max_prefix + 2, 0); 32 | } 33 | 34 | // Adds a key. Assumes that keys are stored in a dense array. 35 | void AddKey(KeyType key) { 36 | if (curr_num_keys_ == 0) { 37 | AddKey(key, /*position=*/0); 38 | return; 39 | } 40 | AddKey(key, prev_position_ + 1); 41 | } 42 | 43 | // Finalizes the construction and returns a read-only `RadixSpline`. 44 | RadixSpline Finalize() { 45 | // Last key needs to be equal to `max_key_`. 46 | assert(curr_num_keys_ == 0 || prev_key_ == max_key_); 47 | 48 | // Ensure that `prev_key_` (== `max_key_`) is last key on spline. 49 | if (curr_num_keys_ > 0 && spline_points_.back().x != prev_key_) 50 | AddKeyToSpline(prev_key_, prev_position_); 51 | 52 | // Maybe even size the radix based on max key right from the start 53 | FinalizeRadixTable(); 54 | 55 | return RadixSpline( 56 | min_key_, max_key_, curr_num_keys_, num_radix_bits_, num_shift_bits_, 57 | max_error_, std::move(radix_table_), std::move(spline_points_)); 58 | } 59 | 60 | private: 61 | // Returns the number of shift bits based on the `diff` between the largest 62 | // and the smallest key. KeyType == uint32_t. 
63 | static size_t GetNumShiftBits(uint32_t diff, size_t num_radix_bits) { 64 | const uint32_t clz = __builtin_clz(diff); 65 | if ((32 - clz) < num_radix_bits) return 0; 66 | return 32 - num_radix_bits - clz; 67 | } 68 | // KeyType == uint64_t. 69 | static size_t GetNumShiftBits(uint64_t diff, size_t num_radix_bits) { 70 | const uint32_t clzl = __builtin_clzl(diff); 71 | if ((64 - clzl) < num_radix_bits) return 0; 72 | return 64 - num_radix_bits - clzl; 73 | } 74 | 75 | void AddKey(KeyType key, size_t position) { 76 | assert(key >= min_key_ && key <= max_key_); 77 | // Keys need to be monotonically increasing. 78 | assert(key >= prev_key_); 79 | // Positions need to be strictly monotonically increasing. 80 | assert(position == 0 || position > prev_position_); 81 | 82 | PossiblyAddKeyToSpline(key, position); 83 | 84 | ++curr_num_keys_; 85 | prev_key_ = key; 86 | prev_position_ = position; 87 | } 88 | 89 | void AddKeyToSpline(KeyType key, double position) { 90 | spline_points_.push_back({key, position}); 91 | PossiblyAddKeyToRadixTable(key); 92 | } 93 | 94 | enum Orientation { Collinear, CW, CCW }; 95 | static constexpr double precision = std::numeric_limits::epsilon(); 96 | 97 | static Orientation ComputeOrientation(const double dx1, const double dy1, 98 | const double dx2, const double dy2) { 99 | const double expr = std::fma(dy1, dx2, -std::fma(dy2, dx1, 0)); 100 | if (expr > precision) 101 | return Orientation::CW; 102 | else if (expr < -precision) 103 | return Orientation::CCW; 104 | return Orientation::Collinear; 105 | }; 106 | 107 | void SetUpperLimit(KeyType key, double position) { 108 | upper_limit_ = {key, position}; 109 | } 110 | void SetLowerLimit(KeyType key, double position) { 111 | lower_limit_ = {key, position}; 112 | } 113 | void RememberPreviousCDFPoint(KeyType key, double position) { 114 | prev_point_ = {key, position}; 115 | } 116 | 117 | // Implementation is based on `GreedySplineCorridor` from: 118 | // T. Neumann and S. Michel. 
Smooth interpolating histograms with error 119 | // guarantees. [BNCOD'08] 120 | void PossiblyAddKeyToSpline(KeyType key, double position) { 121 | if (curr_num_keys_ == 0) { 122 | // Add first CDF point to spline. 123 | AddKeyToSpline(key, position); 124 | ++curr_num_distinct_keys_; 125 | RememberPreviousCDFPoint(key, position); 126 | return; 127 | } 128 | 129 | if (key == prev_key_) { 130 | // No new CDF point if the key didn't change. 131 | return; 132 | } 133 | 134 | // New CDF point. 135 | ++curr_num_distinct_keys_; 136 | 137 | if (curr_num_distinct_keys_ == 2) { 138 | // Initialize `upper_limit_` and `lower_limit_` using the second CDF 139 | // point. 140 | SetUpperLimit(key, position + max_error_); 141 | SetLowerLimit(key, (position < max_error_) ? 0 : position - max_error_); 142 | RememberPreviousCDFPoint(key, position); 143 | return; 144 | } 145 | 146 | // `B` in algorithm. 147 | const Coord& last = spline_points_.back(); 148 | 149 | // Compute current `upper_y` and `lower_y`. 150 | const double upper_y = position + max_error_; 151 | const double lower_y = (position < max_error_) ? 0 : position - max_error_; 152 | 153 | // Compute differences. 154 | assert(upper_limit_.x >= last.x); 155 | assert(lower_limit_.x >= last.x); 156 | assert(key >= last.x); 157 | const double upper_limit_x_diff = upper_limit_.x - last.x; 158 | const double lower_limit_x_diff = lower_limit_.x - last.x; 159 | const double x_diff = key - last.x; 160 | 161 | assert(upper_limit_.y >= last.y); 162 | assert(position >= last.y); 163 | const double upper_limit_y_diff = upper_limit_.y - last.y; 164 | const double lower_limit_y_diff = lower_limit_.y - last.y; 165 | const double y_diff = position - last.y; 166 | 167 | // `prev_point_` is the previous point on the CDF and the next candidate to 168 | // be added to the spline. Hence, it should be different from the `last` 169 | // point on the spline. 170 | assert(prev_point_.x != last.x); 171 | 172 | // Do we cut the error corridor? 
173 | if ((ComputeOrientation(upper_limit_x_diff, upper_limit_y_diff, x_diff, 174 | y_diff) != Orientation::CW) || 175 | (ComputeOrientation(lower_limit_x_diff, lower_limit_y_diff, x_diff, 176 | y_diff) != Orientation::CCW)) { 177 | // Add previous CDF point to spline. 178 | AddKeyToSpline(prev_point_.x, prev_point_.y); 179 | 180 | // Update limits. 181 | SetUpperLimit(key, upper_y); 182 | SetLowerLimit(key, lower_y); 183 | } else { 184 | assert(upper_y >= last.y); 185 | const double upper_y_diff = upper_y - last.y; 186 | if (ComputeOrientation(upper_limit_x_diff, upper_limit_y_diff, x_diff, 187 | upper_y_diff) == Orientation::CW) { 188 | SetUpperLimit(key, upper_y); 189 | } 190 | 191 | const double lower_y_diff = lower_y - last.y; 192 | if (ComputeOrientation(lower_limit_x_diff, lower_limit_y_diff, x_diff, 193 | lower_y_diff) == Orientation::CCW) { 194 | SetLowerLimit(key, lower_y); 195 | } 196 | } 197 | 198 | RememberPreviousCDFPoint(key, position); 199 | } 200 | 201 | void PossiblyAddKeyToRadixTable(KeyType key) { 202 | const KeyType curr_prefix = (key - min_key_) >> num_shift_bits_; 203 | if (curr_prefix != prev_prefix_) { 204 | const uint32_t curr_index = spline_points_.size() - 1; 205 | for (KeyType prefix = prev_prefix_ + 1; prefix <= curr_prefix; ++prefix) 206 | radix_table_[prefix] = curr_index; 207 | prev_prefix_ = curr_prefix; 208 | } 209 | } 210 | 211 | void FinalizeRadixTable() { 212 | ++prev_prefix_; 213 | const uint32_t num_spline_points = spline_points_.size(); 214 | for (; prev_prefix_ < radix_table_.size(); ++prev_prefix_) 215 | radix_table_[prev_prefix_] = num_spline_points; 216 | } 217 | 218 | const KeyType min_key_; 219 | const KeyType max_key_; 220 | const size_t num_radix_bits_; 221 | const size_t num_shift_bits_; 222 | const size_t max_error_; 223 | 224 | std::vector radix_table_; 225 | std::vector> spline_points_; 226 | 227 | size_t curr_num_keys_; 228 | size_t curr_num_distinct_keys_; 229 | KeyType prev_key_; 230 | size_t prev_position_; 
231 | KeyType prev_prefix_; 232 | 233 | // Current upper and lower limits on the error corridor of the spline. 234 | Coord upper_limit_; 235 | Coord lower_limit_; 236 | 237 | // Previous CDF point. 238 | Coord prev_point_; 239 | }; 240 | 241 | } // namespace rs -------------------------------------------------------------------------------- /include/rs/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace rs { 7 | 8 | // A CDF coordinate. 9 | template 10 | struct Coord { 11 | KeyType x; 12 | double y; 13 | }; 14 | 15 | struct SearchBound { 16 | size_t begin; 17 | size_t end; // Exclusive. 18 | }; 19 | 20 | } // namespace rs 21 | -------------------------------------------------------------------------------- /include/rs/multi_map.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "builder.h" 8 | #include "radix_spline.h" 9 | 10 | namespace rs { 11 | 12 | // A drop-in replacement for std::multimap. Internally creates a sorted copy of 13 | // the data. 14 | template 15 | class MultiMap { 16 | public: 17 | // Member type definitions. 18 | using key_type = KeyType; 19 | using mapped_type = ValueType; 20 | using value_type = std::pair; 21 | using size_type = std::size_t; 22 | using iterator = typename std::vector::iterator; 23 | using const_iterator = typename std::vector::const_iterator; 24 | 25 | // Constructor, creates a copy of the data. 26 | template 27 | MultiMap(BidirIt first, BidirIt last, size_t num_radix_bits = 18, 28 | size_t max_error = 32); 29 | 30 | // Lookup functions, like in std::multimap. 31 | const_iterator find(KeyType key) const; 32 | const_iterator lower_bound(KeyType key) const; 33 | 34 | // Iterators. 35 | const_iterator begin() const { return data_.begin(); } 36 | const_iterator end() const { return data_.end(); } 37 | 38 | // Size. 
39 | std::size_t size() const { return data_.size(); } 40 | 41 | private: 42 | std::vector data_; 43 | RadixSpline rs_; 44 | }; 45 | 46 | template 47 | template 48 | MultiMap::MultiMap(BidirIt first, BidirIt last, 49 | size_t num_radix_bits, 50 | size_t max_error) { 51 | // Empty spline. 52 | if (first == last) { 53 | rs::Builder rsb(std::numeric_limits::min(), 54 | std::numeric_limits::max(), 55 | num_radix_bits, max_error); 56 | rs_ = rsb.Finalize(); 57 | return; 58 | } 59 | 60 | // Copy data and check if sorted. 61 | bool is_sorted = true; 62 | KeyType previous_key = first->first; 63 | for (auto current = first; current != last; ++current) { 64 | is_sorted &= current->first >= previous_key; 65 | previous_key = current->first; 66 | data_.push_back(*current); 67 | } 68 | 69 | // Sort if necessary. 70 | if (!is_sorted) { 71 | std::sort(data_.begin(), data_.end(), 72 | [](const value_type& lhs, const value_type& rhs) { 73 | return lhs.first < rhs.first; 74 | }); 75 | } 76 | 77 | // Create spline builder. 78 | const auto min_key = data_.front().first; 79 | const auto max_key = data_.back().first; 80 | rs::Builder rsb(min_key, max_key, num_radix_bits, max_error); 81 | 82 | // Build the radix spline. 83 | for (const auto& iter : data_) { 84 | rsb.AddKey(iter.first); 85 | } 86 | rs_ = rsb.Finalize(); 87 | } 88 | 89 | template 90 | typename MultiMap::const_iterator 91 | MultiMap::lower_bound(KeyType key) const { 92 | SearchBound bound = rs_.GetSearchBound(key); 93 | return std::lower_bound(data_.begin() + bound.begin, 94 | data_.begin() + bound.end, key, 95 | [](const value_type& lhs, const KeyType& rhs) { 96 | return lhs.first < rhs; 97 | }); 98 | } 99 | 100 | template 101 | typename MultiMap::const_iterator 102 | MultiMap::find(KeyType key) const { 103 | auto iter = lower_bound(key); 104 | return iter != data_.end() && iter->first == key ? 
iter : data_.end(); 105 | } 106 | 107 | } // namespace rs -------------------------------------------------------------------------------- /include/rs/radix_spline.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "common.h" 9 | 10 | namespace rs { 11 | 12 | // Approximates a cumulative distribution function (CDF) using spline 13 | // interpolation. 14 | template 15 | class RadixSpline { 16 | public: 17 | RadixSpline() = default; 18 | 19 | RadixSpline(KeyType min_key, KeyType max_key, size_t num_keys, 20 | size_t num_radix_bits, size_t num_shift_bits, size_t max_error, 21 | std::vector radix_table, 22 | std::vector> spline_points) 23 | : min_key_(min_key), 24 | max_key_(max_key), 25 | num_keys_(num_keys), 26 | num_radix_bits_(num_radix_bits), 27 | num_shift_bits_(num_shift_bits), 28 | max_error_(max_error), 29 | radix_table_(std::move(radix_table)), 30 | spline_points_(std::move(spline_points)) {} 31 | 32 | // Returns the estimated position of `key`. 33 | double GetEstimatedPosition(const KeyType key) const { 34 | // Truncate to data boundaries. 35 | if (key <= min_key_) return 0; 36 | if (key >= max_key_) return num_keys_ - 1; 37 | 38 | // Find spline segment with `key` ∈ (spline[index - 1], spline[index]]. 39 | const size_t index = GetSplineSegment(key); 40 | const Coord down = spline_points_[index - 1]; 41 | const Coord up = spline_points_[index]; 42 | 43 | // Compute slope. 44 | const double x_diff = up.x - down.x; 45 | const double y_diff = up.y - down.y; 46 | const double slope = y_diff / x_diff; 47 | 48 | // Interpolate. 49 | const double key_diff = key - down.x; 50 | return std::fma(key_diff, slope, down.y); 51 | } 52 | 53 | // Returns a search bound [begin, end) around the estimated position. 
54 | SearchBound GetSearchBound(const KeyType key) const { 55 | const size_t estimate = GetEstimatedPosition(key); 56 | const size_t begin = (estimate < max_error_) ? 0 : (estimate - max_error_); 57 | // `end` is exclusive. 58 | const size_t end = (estimate + max_error_ + 2 > num_keys_) 59 | ? num_keys_ 60 | : (estimate + max_error_ + 2); 61 | return SearchBound{begin, end}; 62 | } 63 | 64 | // Returns the size in bytes. 65 | size_t GetSize() const { 66 | return sizeof(*this) + radix_table_.size() * sizeof(uint32_t) + 67 | spline_points_.size() * sizeof(Coord); 68 | } 69 | 70 | private: 71 | // Returns the index of the spline point that marks the end of the spline 72 | // segment that contains the `key`: `key` ∈ (spline[index - 1], spline[index]] 73 | size_t GetSplineSegment(const KeyType key) const { 74 | // Narrow search range using radix table. 75 | const KeyType prefix = (key - min_key_) >> num_shift_bits_; 76 | assert(prefix + 1 < radix_table_.size()); 77 | const uint32_t begin = radix_table_[prefix]; 78 | const uint32_t end = radix_table_[prefix + 1]; 79 | 80 | if (end - begin < 32) { 81 | // Do linear search over narrowed range. 82 | uint32_t current = begin; 83 | while (spline_points_[current].x < key) ++current; 84 | return current; 85 | } 86 | 87 | // Do binary search over narrowed range. 
88 | const auto lb = std::lower_bound( 89 | spline_points_.begin() + begin, spline_points_.begin() + end, key, 90 | [](const Coord& coord, const KeyType key) { 91 | return coord.x < key; 92 | }); 93 | return std::distance(spline_points_.begin(), lb); 94 | } 95 | 96 | KeyType min_key_; 97 | KeyType max_key_; 98 | size_t num_keys_; 99 | size_t num_radix_bits_; 100 | size_t num_shift_bits_; 101 | size_t max_error_; 102 | 103 | std::vector radix_table_; 104 | std::vector> spline_points_; 105 | 106 | template 107 | friend class Serializer; 108 | }; 109 | 110 | } // namespace rs -------------------------------------------------------------------------------- /include/rs/serializer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "radix_spline.h" 6 | 7 | namespace rs { 8 | 9 | template 10 | class Serializer { 11 | public: 12 | // Serializes the `rs` model and appends it to `bytes`. 13 | static void ToBytes(const RadixSpline& rs, std::string* bytes) { 14 | std::stringstream buffer; 15 | 16 | // Scalar members. 17 | buffer.write(reinterpret_cast(&rs.min_key_), sizeof(KeyType)); 18 | buffer.write(reinterpret_cast(&rs.max_key_), sizeof(KeyType)); 19 | buffer.write(reinterpret_cast(&rs.num_keys_), sizeof(size_t)); 20 | buffer.write(reinterpret_cast(&rs.num_radix_bits_), 21 | sizeof(size_t)); 22 | buffer.write(reinterpret_cast(&rs.num_shift_bits_), 23 | sizeof(size_t)); 24 | buffer.write(reinterpret_cast(&rs.max_error_), sizeof(size_t)); 25 | 26 | // Radix table. 27 | const size_t radix_table_size = rs.radix_table_.size(); 28 | buffer.write(reinterpret_cast(&radix_table_size), 29 | sizeof(size_t)); 30 | for (size_t i = 0; i < rs.radix_table_.size(); ++i) { 31 | buffer.write(reinterpret_cast(&rs.radix_table_[i]), 32 | sizeof(uint32_t)); 33 | } 34 | 35 | // Spline points. 
36 | const size_t spline_points_size = rs.spline_points_.size(); 37 | buffer.write(reinterpret_cast(&spline_points_size), 38 | sizeof(size_t)); 39 | for (size_t i = 0; i < rs.spline_points_.size(); ++i) { 40 | buffer.write(reinterpret_cast(&rs.spline_points_[i].x), 41 | sizeof(KeyType)); 42 | buffer.write(reinterpret_cast(&rs.spline_points_[i].y), 43 | sizeof(double)); 44 | } 45 | 46 | bytes->append(buffer.str()); 47 | } 48 | 49 | static RadixSpline FromBytes(const std::string& bytes) { 50 | std::istringstream in(bytes); 51 | 52 | RadixSpline rs; 53 | 54 | // Scalar members. 55 | in.read(reinterpret_cast(&rs.min_key_), sizeof(KeyType)); 56 | in.read(reinterpret_cast(&rs.max_key_), sizeof(KeyType)); 57 | in.read(reinterpret_cast(&rs.num_keys_), sizeof(size_t)); 58 | in.read(reinterpret_cast(&rs.num_radix_bits_), sizeof(size_t)); 59 | in.read(reinterpret_cast(&rs.num_shift_bits_), sizeof(size_t)); 60 | in.read(reinterpret_cast(&rs.max_error_), sizeof(size_t)); 61 | 62 | // Radix table. 63 | size_t radix_table_size; 64 | in.read(reinterpret_cast(&radix_table_size), sizeof(size_t)); 65 | rs.radix_table_.resize(radix_table_size); 66 | for (int i = 0; i < rs.radix_table_.size(); ++i) { 67 | in.read(reinterpret_cast(&rs.radix_table_[i]), sizeof(uint32_t)); 68 | } 69 | 70 | // Spline points. 
71 | size_t spline_points_size; 72 | in.read(reinterpret_cast(&spline_points_size), sizeof(size_t)); 73 | rs.spline_points_.resize(spline_points_size); 74 | for (int i = 0; i < rs.spline_points_.size(); ++i) { 75 | in.read(reinterpret_cast(&rs.spline_points_[i].x), 76 | sizeof(KeyType)); 77 | in.read(reinterpret_cast(&rs.spline_points_[i].y), sizeof(double)); 78 | } 79 | 80 | return rs; 81 | } 82 | }; 83 | 84 | } // namespace rs 85 | -------------------------------------------------------------------------------- /test/multi_map_test.cc: -------------------------------------------------------------------------------- 1 | #include "include/rs/multi_map.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "gtest/gtest.h" 7 | 8 | namespace { 9 | 10 | TEST(MultiMapTest, SimpleFind) { 11 | std::vector> data = {{1ull, 'a'}, 12 | {12ull, 'c'}, 13 | {7ull, 'b'}, // Unsorted. 14 | {42ull, 'd'}}; 15 | rs::MultiMap rs_multi_map(data.begin(), data.end()); 16 | 17 | // Positive lookups (keys). 18 | ASSERT_EQ(1u, rs_multi_map.find(1)->first); 19 | ASSERT_EQ(7u, rs_multi_map.find(7)->first); 20 | ASSERT_EQ(12u, rs_multi_map.find(12)->first); 21 | ASSERT_EQ(42u, rs_multi_map.find(42)->first); 22 | 23 | // Positive lookups (values). 24 | ASSERT_EQ('a', rs_multi_map.find(1)->second); 25 | ASSERT_EQ('b', rs_multi_map.find(7)->second); 26 | ASSERT_EQ('c', rs_multi_map.find(12)->second); 27 | ASSERT_EQ('d', rs_multi_map.find(42)->second); 28 | 29 | // Negative lookups. 
30 | ASSERT_EQ(rs_multi_map.end(), rs_multi_map.find(0)); 31 | ASSERT_EQ(rs_multi_map.end(), rs_multi_map.find(2)); 32 | ASSERT_EQ(rs_multi_map.end(), rs_multi_map.find(6)); 33 | ASSERT_EQ(rs_multi_map.end(), rs_multi_map.find(8)); 34 | ASSERT_EQ(rs_multi_map.end(), rs_multi_map.find(11)); 35 | ASSERT_EQ(rs_multi_map.end(), rs_multi_map.find(13)); 36 | ASSERT_EQ(rs_multi_map.end(), rs_multi_map.find(41)); 37 | ASSERT_EQ(rs_multi_map.end(), rs_multi_map.find(43)); 38 | } 39 | 40 | TEST(MultiMapTest, LowerBoundFind) { 41 | std::vector> data = {{1ull, 'a'}, 42 | {12ull, 'c'}, 43 | {7ull, 'b'}, // Unsorted. 44 | {42ull, 'd'}}; 45 | rs::MultiMap rs_multi_map(data.begin(), data.end()); 46 | 47 | // Direct-hit lookups (keys). 48 | ASSERT_EQ(1u, rs_multi_map.lower_bound(1)->first); 49 | ASSERT_EQ(7u, rs_multi_map.lower_bound(7)->first); 50 | ASSERT_EQ(12u, rs_multi_map.lower_bound(12)->first); 51 | ASSERT_EQ(42u, rs_multi_map.lower_bound(42)->first); 52 | 53 | // Direct-hit lookups (values). 54 | ASSERT_EQ('a', rs_multi_map.lower_bound(1)->second); 55 | ASSERT_EQ('b', rs_multi_map.lower_bound(7)->second); 56 | ASSERT_EQ('c', rs_multi_map.lower_bound(12)->second); 57 | ASSERT_EQ('d', rs_multi_map.lower_bound(42)->second); 58 | 59 | // Negative lookups (keys). 60 | ASSERT_EQ(1u, rs_multi_map.lower_bound(0)->first); 61 | ASSERT_EQ(7u, rs_multi_map.lower_bound(2)->first); 62 | ASSERT_EQ(7u, rs_multi_map.lower_bound(6)->first); 63 | ASSERT_EQ(12u, rs_multi_map.lower_bound(8)->first); 64 | ASSERT_EQ(12u, rs_multi_map.lower_bound(11)->first); 65 | ASSERT_EQ(42u, rs_multi_map.lower_bound(13)->first); 66 | ASSERT_EQ(42u, rs_multi_map.lower_bound(41)->first); 67 | ASSERT_EQ(rs_multi_map.end(), rs_multi_map.lower_bound(43)); 68 | } 69 | 70 | const size_t kNumKeys = 500; 71 | const size_t kNumLookups = 500; 72 | 73 | TEST(MultiMapTest, Random) { 74 | // Create random rs::MultiMap and std::multimap. 
75 | std::vector> entries; 76 | entries.reserve(kNumKeys); 77 | std::mt19937 randomness_generator(8128); 78 | std::uniform_int_distribution distribution(0, kNumKeys * 10); 79 | while (entries.size() < kNumKeys) { 80 | entries.emplace_back(distribution(randomness_generator), entries.size()); 81 | } 82 | rs::MultiMap map(entries.begin(), entries.end()); 83 | std::multimap ref(entries.begin(), entries.end()); 84 | 85 | // Look up every key in the generated range 86 | for (size_t lookup_key = 0; lookup_key < kNumKeys * 10 + 10; ++lookup_key) { 87 | // Check lower bound 88 | auto map_iter = map.lower_bound(lookup_key); 89 | auto ref_iter = ref.lower_bound(lookup_key); 90 | 91 | // Found something -> iterate until the end and remember all values found. 92 | std::set found_values_in_map; 93 | std::set found_values_in_ref; 94 | while (map_iter != map.end() && ref_iter != ref.end()) { 95 | ASSERT_EQ(ref_iter->first, map_iter->first); 96 | found_values_in_map.insert(map_iter->second); 97 | found_values_in_ref.insert(ref_iter->second); 98 | ++ref_iter; 99 | ++map_iter; 100 | } 101 | 102 | // Both should be at the end now. 103 | ASSERT_EQ(map.end(), map_iter); 104 | ASSERT_EQ(ref.end(), ref_iter); 105 | 106 | // Check that all encountered values are the same. 
107 | ASSERT_EQ(found_values_in_ref.size(), found_values_in_map.size()); 108 | auto val_iter_map = found_values_in_map.begin(); 109 | auto val_iter_ref = found_values_in_ref.begin(); 110 | for (size_t i = 0; i < found_values_in_ref.size(); ++i) { 111 | ASSERT_EQ(*val_iter_ref, *val_iter_map); 112 | ++val_iter_map; 113 | ++val_iter_ref; 114 | } 115 | } 116 | } 117 | 118 | } // namespace -------------------------------------------------------------------------------- /test/radix_spline_test.cc: -------------------------------------------------------------------------------- 1 | #include "include/rs/radix_spline.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "gtest/gtest.h" 7 | #include "include/rs/builder.h" 8 | #include "include/rs/serializer.h" 9 | 10 | const size_t kNumKeys = 1000; 11 | // Number of iterations (seeds) of random positive and negative test cases. 12 | const size_t kNumIterations = 10; 13 | const size_t kNumRadixBits = 18; 14 | const size_t kMaxError = 32; 15 | 16 | namespace { 17 | 18 | // *** Helper methods *** 19 | 20 | template 21 | std::vector CreateDenseKeys() { 22 | std::vector keys; 23 | keys.reserve(kNumKeys); 24 | for (size_t i = 0; i < kNumKeys; ++i) keys.push_back(i); 25 | return keys; 26 | } 27 | 28 | template 29 | std::vector CreateUniqueRandomKeys(size_t seed) { 30 | std::unordered_set keys; 31 | keys.reserve(kNumKeys); 32 | std::mt19937 g(seed); 33 | std::uniform_int_distribution d(std::numeric_limits::min(), 34 | std::numeric_limits::max()); 35 | while (keys.size() < kNumKeys) keys.insert(d(g)); 36 | std::vector sorted_keys(keys.begin(), keys.end()); 37 | std::sort(sorted_keys.begin(), sorted_keys.end()); 38 | return sorted_keys; 39 | } 40 | 41 | // Creates lognormal distributed keys, possibly with duplicates. 42 | template 43 | std::vector CreateSkewedKeys(size_t seed) { 44 | std::vector keys; 45 | keys.reserve(kNumKeys); 46 | 47 | // Generate lognormal values. 
48 | std::mt19937 g(seed); 49 | std::lognormal_distribution d(/*mean*/ 0, /*stddev=*/2); 50 | std::vector lognormal_values; 51 | lognormal_values.reserve(kNumKeys); 52 | for (size_t i = 0; i < kNumKeys; ++i) lognormal_values.push_back(d(g)); 53 | const auto min_max = 54 | std::minmax_element(lognormal_values.begin(), lognormal_values.end()); 55 | const double min = *min_max.first; 56 | const double max = *min_max.second; 57 | const double diff = max - min; 58 | 59 | // Scale values to the entire `KeyType` domain. 60 | const auto domain = 61 | std::numeric_limits::max() - std::numeric_limits::min(); 62 | for (size_t i = 0; i < kNumKeys; ++i) { 63 | const double ratio = (lognormal_values[i] - min) / diff; 64 | keys.push_back(ratio * domain); 65 | } 66 | 67 | std::sort(keys.begin(), keys.end()); 68 | return keys; 69 | } 70 | 71 | template 72 | rs::RadixSpline CreateRadixSpline(const std::vector& keys) { 73 | auto min = std::numeric_limits::min(); 74 | auto max = std::numeric_limits::max(); 75 | if (keys.size() > 0) { 76 | min = keys.front(); 77 | max = keys.back(); 78 | } 79 | rs::Builder rsb(min, max, kNumRadixBits, kMaxError); 80 | for (const auto& key : keys) rsb.AddKey(key); 81 | return rsb.Finalize(); 82 | } 83 | 84 | template 85 | bool BoundContains(const std::vector& keys, rs::SearchBound bound, 86 | KeyType key) { 87 | const auto it = std::lower_bound(keys.begin() + bound.begin, 88 | keys.begin() + bound.end, key); 89 | if (it == keys.end()) return false; 90 | return *it == key; 91 | } 92 | 93 | // *** Tests *** 94 | 95 | template 96 | struct RadixSplineTest : public testing::Test { 97 | using KeyType = T; 98 | }; 99 | 100 | using AllKeyTypes = testing::Types; 101 | TYPED_TEST_SUITE(RadixSplineTest, AllKeyTypes); 102 | 103 | TYPED_TEST(RadixSplineTest, AddAndLookupDenseKeys) { 104 | using KeyType = typename TestFixture::KeyType; 105 | const auto keys = CreateDenseKeys(); 106 | const auto rs = CreateRadixSpline(keys); 107 | for (const auto& key : keys) 108 | 
EXPECT_TRUE(BoundContains(keys, rs.GetSearchBound(key), key)) 109 | << "key: " << key; 110 | } 111 | 112 | TYPED_TEST(RadixSplineTest, AddAndLookupRandomKeysPositiveLookups) { 113 | using KeyType = typename TestFixture::KeyType; 114 | for (size_t i = 0; i < kNumIterations; ++i) { 115 | const auto keys = CreateUniqueRandomKeys(/*seed=*/i); 116 | const auto rs = CreateRadixSpline(keys); 117 | for (const auto& key : keys) 118 | EXPECT_TRUE(BoundContains(keys, rs.GetSearchBound(key), key)) 119 | << "key: " << key; 120 | } 121 | } 122 | 123 | TYPED_TEST(RadixSplineTest, AddAndLookupRandomIntegersNegativeLookups) { 124 | using KeyType = typename TestFixture::KeyType; 125 | for (size_t i = 0; i < kNumIterations; ++i) { 126 | const auto keys = CreateUniqueRandomKeys(/*seed=*/42 + i); 127 | const auto lookup_keys = CreateUniqueRandomKeys(/*seed=*/815 + i); 128 | const auto rs = CreateRadixSpline(keys); 129 | for (const auto& key : lookup_keys) { 130 | if (!BoundContains(keys, rs::SearchBound{0, keys.size()}, key)) 131 | EXPECT_FALSE(BoundContains(keys, rs.GetSearchBound(key), key)) 132 | << "key: " << key; 133 | } 134 | } 135 | } 136 | 137 | TYPED_TEST(RadixSplineTest, 138 | AddAndLookupRandomIntegersWithDuplicatesPositiveLookups) { 139 | using KeyType = typename TestFixture::KeyType; 140 | 141 | // Duplicate every key once. 
142 | auto duplicated_keys = CreateUniqueRandomKeys(/*seed=*/42); 143 | const size_t size = duplicated_keys.size(); 144 | for (size_t i = 0; i < size; ++i) 145 | duplicated_keys.push_back(duplicated_keys[i]); 146 | std::sort(duplicated_keys.begin(), duplicated_keys.end()); 147 | 148 | const auto rs = CreateRadixSpline(duplicated_keys); 149 | for (const auto& key : duplicated_keys) 150 | EXPECT_TRUE(BoundContains(duplicated_keys, rs.GetSearchBound(key), key)) 151 | << "key: " << key; 152 | } 153 | 154 | TYPED_TEST(RadixSplineTest, AddAndLookupSkewedKeysPositiveLookups) { 155 | using KeyType = typename TestFixture::KeyType; 156 | for (size_t i = 0; i < kNumIterations; ++i) { 157 | const auto keys = CreateSkewedKeys(/*seed=*/i); 158 | const auto rs = CreateRadixSpline(keys); 159 | for (const auto& key : keys) 160 | EXPECT_TRUE(BoundContains(keys, rs.GetSearchBound(key), key)) 161 | << "key: " << key; 162 | } 163 | } 164 | 165 | TYPED_TEST(RadixSplineTest, AddAndLookupSkewedKeysNegativeLookups) { 166 | using KeyType = typename TestFixture::KeyType; 167 | for (size_t i = 0; i < kNumIterations; ++i) { 168 | const auto keys = CreateSkewedKeys(/*seed=*/42 + i); 169 | const auto lookup_keys = CreateSkewedKeys(/*seed=*/815 + i); 170 | const auto rs = CreateRadixSpline(keys); 171 | for (const auto& key : lookup_keys) { 172 | if (!BoundContains(keys, rs::SearchBound{0, keys.size()}, key)) 173 | EXPECT_FALSE(BoundContains(keys, rs.GetSearchBound(key), key)) 174 | << "key: " << key; 175 | } 176 | } 177 | } 178 | 179 | TYPED_TEST(RadixSplineTest, GetEstimatedPosKeyOutOfRange) { 180 | using KeyType = typename TestFixture::KeyType; 181 | const std::vector keys = {1, 2, 3}; 182 | const auto rs = CreateRadixSpline(keys); 183 | EXPECT_EQ(rs.GetEstimatedPosition(0), 0u); 184 | EXPECT_EQ(rs.GetEstimatedPosition(4), keys.size() - 1); 185 | } 186 | 187 | TYPED_TEST(RadixSplineTest, NoKey) { 188 | using KeyType = typename TestFixture::KeyType; 189 | const std::vector keys; 190 | const auto 
rs = CreateRadixSpline(keys); 191 | // We expect the size to be at most the size of rs::RadixSpline and the size 192 | // of the pre-allocated radix table. 193 | EXPECT_TRUE(rs.GetSize() <= 194 | sizeof(rs::RadixSpline) + 195 | ((1ull << kNumRadixBits) + 1) * sizeof(uint32_t)); 196 | } 197 | 198 | TYPED_TEST(RadixSplineTest, SingleKey) { 199 | using KeyType = typename TestFixture::KeyType; 200 | const auto key = std::numeric_limits::min(); 201 | const std::vector keys = {key}; 202 | const auto rs = CreateRadixSpline(keys); 203 | EXPECT_EQ(rs.GetEstimatedPosition(key), 0u); 204 | EXPECT_TRUE(BoundContains(keys, rs.GetSearchBound(key), key)) 205 | << "key: " << key; 206 | } 207 | 208 | TYPED_TEST(RadixSplineTest, TwoKeys) { 209 | using KeyType = typename TestFixture::KeyType; 210 | const auto key1 = std::numeric_limits::min(); 211 | const auto key2 = std::numeric_limits::max(); 212 | const std::vector keys = {key1, key2}; 213 | const auto rs = CreateRadixSpline(keys); 214 | for (const auto& key : keys) 215 | EXPECT_TRUE(BoundContains(keys, rs.GetSearchBound(key), key)) 216 | << "key: " << key; 217 | } 218 | 219 | TYPED_TEST(RadixSplineTest, AllMinKeys) { 220 | using KeyType = typename TestFixture::KeyType; 221 | const auto key = std::numeric_limits::min(); 222 | const std::vector keys(kNumKeys, key); 223 | const auto rs = CreateRadixSpline(keys); 224 | EXPECT_TRUE(BoundContains(keys, rs.GetSearchBound(key), key)) 225 | << "key: " << key; 226 | } 227 | 228 | TYPED_TEST(RadixSplineTest, AllMaxKeys) { 229 | using KeyType = typename TestFixture::KeyType; 230 | const auto key = std::numeric_limits::max(); 231 | const std::vector keys(kNumKeys, key); 232 | const auto rs = CreateRadixSpline(keys); 233 | EXPECT_TRUE(BoundContains(keys, rs.GetSearchBound(key), key)) 234 | << "key: " << key; 235 | } 236 | 237 | TYPED_TEST(RadixSplineTest, Serialize) { 238 | using KeyType = typename TestFixture::KeyType; 239 | const auto keys = CreateDenseKeys(); 240 | const auto rs = 
CreateRadixSpline(keys); 241 | 242 | rs::Serializer serializer; 243 | 244 | // Serialize. 245 | std::string bytes; 246 | serializer.ToBytes(rs, &bytes); 247 | 248 | // Deserialize. 249 | const auto rs_deserialized = serializer.FromBytes(bytes); 250 | 251 | ASSERT_EQ(rs.GetSize(), rs_deserialized.GetSize()); 252 | for (const auto& key : keys) 253 | ASSERT_EQ(rs.GetEstimatedPosition(key), 254 | rs_deserialized.GetEstimatedPosition(key)); 255 | } 256 | 257 | } // namespace --------------------------------------------------------------------------------