├── .clang-format ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── LICENSE ├── README.md ├── server ├── CMakeLists.txt ├── command_registry.cc ├── command_registry.h ├── common_types.cc ├── common_types.h ├── config_flags.cc ├── config_flags.h ├── conn_context.cc ├── conn_context.h ├── db_slice.cc ├── db_slice.h ├── debugcmd.cc ├── debugcmd.h ├── dfly_main.cc ├── dfly_protocol.h ├── dragonfly_connection.cc ├── dragonfly_connection.h ├── dragonfly_listener.cc ├── dragonfly_listener.h ├── engine_shard_set.cc ├── engine_shard_set.h ├── main_service.cc ├── main_service.h ├── memcache_parser.cc ├── memcache_parser.h ├── memcache_parser_test.cc ├── op_status.h ├── redis_parser.cc ├── redis_parser.h ├── redis_parser_test.cc ├── reply_builder.cc ├── reply_builder.h ├── resp_expr.cc ├── resp_expr.h ├── test_utils.cc └── test_utils.h └── string_set ├── CMakeLists.txt ├── string_set.cc ├── string_set.h └── string_set_test.cc /.clang-format: -------------------------------------------------------------------------------- 1 | # --- 2 | # We'll use defaults from the Google style, but with 2 columns indentation. 
3 | BasedOnStyle: Google 4 | IndentWidth: 2 5 | ColumnLimit: 100 6 | --- 7 | Language: Cpp 8 | AllowShortLoopsOnASingleLine: false 9 | AllowShortFunctionsOnASingleLine: false 10 | AllowShortIfStatementsOnASingleLine: false 11 | AlwaysBreakTemplateDeclarations: false 12 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 13 | DerivePointerAlignment: false 14 | PointerAlignment: Left 15 | BasedOnStyle: Google 16 | ColumnLimit: 100 17 | --- 18 | Language: Proto 19 | BasedOnStyle: Google 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/* 2 | build-* 3 | .vscode/*.db 4 | .vscode/settings.json 5 | third_party 6 | genfiles/* 7 | *.sublime-* 8 | .tags 9 | !third_party/include/* 10 | *.pyc 11 | /CMakeLists.txt.user 12 | _deps 13 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "helio"] 2 | path = helio 3 | url = https://github.com/romange/helio.git 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.15 FATAL_ERROR) 2 | set(PROJECT_CONTACT romange@gmail.com) 3 | 4 | enable_testing() 5 | 6 | set(CMAKE_EXPORT_COMPILE_COMMANDS 1) 7 | 8 | # Set targets in folders 9 | set_property(GLOBAL PROPERTY USE_FOLDERS ON) 10 | project(DRAGONFLY C CXX) 11 | set(CMAKE_CXX_STANDARD 17) 12 | 13 | # We must define all the required variables from the root cmakefile, otherwise 14 | # they just disappear. 
15 | set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/helio/cmake" ${CMAKE_MODULE_PATH}) 16 | option(BUILD_SHARED_LIBS "Build shared libraries" OFF) 17 | 18 | include(third_party) 19 | include(internal) 20 | 21 | Message(STATUS "THIRD_PARTY_LIB_DIR ${THIRD_PARTY_LIB_DIR}") 22 | 23 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}) 24 | include_directories(helio) 25 | 26 | add_subdirectory(helio) 27 | add_subdirectory(server) 28 | add_subdirectory(string_set) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 10 | 11 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 12 | 13 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 14 | 15 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 16 | 17 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
18 | 19 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 20 | 21 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 22 | 23 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 24 | 25 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
26 | 27 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 28 | 29 | 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 30 | 31 | 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 32 | 33 | 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 34 | 35 | You must give any other recipients of the Work or Derivative Works a copy of this License; and 36 | You must cause any modified files to carry prominent notices stating that You changed the files; and 37 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and 38 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 
39 | 40 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 41 | 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 42 | 43 | 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 44 | 45 | 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 46 | 47 | 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 48 | 49 | 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # midi-redis 2 | 3 | A toy memory store that supports basic commands like `SET` and `GET` for both memcached and redis protocols. 4 | In addition, it supports redis `PING` command. 5 | 6 | Demo features include: 7 | 1. High throughput reaching millions of QPS on a single node. 8 | 2. TLS support. 9 | 3. Pipelining mode. 10 | 11 | ## Building from source 12 | I've tested the build on Ubuntu 21.04+. 
13 | 14 | 15 | ``` 16 | git clone --recursive https://github.com/romange/midi-redis 17 | cd midi-redis && ./helio/blaze.sh -release 18 | cd build-opt && ninja midi-redis 19 | 20 | ``` 21 | 22 | ## Running 23 | 24 | ``` 25 | ./midi-redis --logtostderr 26 | ``` 27 | 28 | for more options, run `./midi-redis --help` 29 | -------------------------------------------------------------------------------- /server/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(midi-redis dfly_main.cc) 2 | cxx_link(midi-redis base dragonfly_lib) 3 | 4 | add_library(dragonfly_lib command_registry.cc common_types.cc config_flags.cc 5 | conn_context.cc db_slice.cc debugcmd.cc dragonfly_listener.cc 6 | dragonfly_connection.cc engine_shard_set.cc 7 | main_service.cc memcache_parser.cc 8 | redis_parser.cc resp_expr.cc reply_builder.cc) 9 | 10 | cxx_link(dragonfly_lib uring_fiber_lib 11 | fibers_ext strings_lib http_server_lib tls_lib) 12 | 13 | add_library(dfly_test_lib test_utils.cc) 14 | cxx_link(dfly_test_lib dragonfly_lib gtest_main_ext) 15 | 16 | cxx_test(redis_parser_test dfly_test_lib LABELS DFLY) 17 | cxx_test(memcache_parser_test dfly_test_lib LABELS DFLY) 18 | -------------------------------------------------------------------------------- /server/command_registry.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | 5 | #include "server/command_registry.h" 6 | 7 | #include "absl/strings/str_cat.h" 8 | #include "base/bits.h" 9 | #include "base/logging.h" 10 | #include "server/conn_context.h" 11 | 12 | using namespace std; 13 | 14 | namespace dfly { 15 | 16 | using absl::StrAppend; 17 | using absl::StrCat; 18 | 19 | CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first_key, 20 | int8_t last_key, int8_t step) 21 | : name_(name), opt_mask_(mask), arity_(arity), first_key_(first_key), last_key_(last_key), 22 | step_key_(step) { 23 | } 24 | 25 | uint32_t CommandId::OptCount(uint32_t mask) { 26 | return absl::popcount(mask); 27 | } 28 | 29 | CommandRegistry::CommandRegistry() { 30 | CommandId cd("COMMAND", CO::RANDOM | CO::LOADING | CO::STALE, 0, 0, 0, 0); 31 | 32 | cd.SetHandler([this](const auto& args, auto* cntx) { return Command(args, cntx); }); 33 | const char* nm = cd.name(); 34 | cmd_map_.emplace(nm, std::move(cd)); 35 | } 36 | 37 | void CommandRegistry::Command(CmdArgList args, ConnectionContext* cntx) { 38 | size_t sz = cmd_map_.size(); 39 | string resp = StrCat("*", sz, "\r\n"); 40 | 41 | for (const auto& val : cmd_map_) { 42 | const CommandId& cd = val.second; 43 | StrAppend(&resp, "*6\r\n$", strlen(cd.name()), "\r\n", cd.name(), "\r\n"); 44 | StrAppend(&resp, ":", int(cd.arity()), "\r\n"); 45 | StrAppend(&resp, "*", CommandId::OptCount(cd.opt_mask()), "\r\n"); 46 | 47 | for (uint32_t i = 0; i < 32; ++i) { 48 | unsigned obit = (1u << i); 49 | if (cd.opt_mask() & obit) { 50 | const char* name = CO::OptName(CO::CommandOpt{obit}); 51 | StrAppend(&resp, "+", name, "\r\n"); 52 | } 53 | } 54 | 55 | StrAppend(&resp, ":", cd.first_key_pos(), "\r\n"); 56 | StrAppend(&resp, ":", cd.last_key_pos(), "\r\n"); 57 | StrAppend(&resp, ":", cd.key_arg_step(), "\r\n"); 58 | } 59 | 60 | cntx->SendRespBlob(resp); 61 | } 62 | 63 | namespace CO { 64 | 65 | const char* OptName(CO::CommandOpt fl) { 66 | using namespace CO; 67 | 68 | switch (fl) { 69 | case 
WRITE: 70 | return "write"; 71 | case READONLY: 72 | return "readonly"; 73 | case DENYOOM: 74 | return "denyoom"; 75 | case FAST: 76 | return "fast"; 77 | case STALE: 78 | return "stale"; 79 | case LOADING: 80 | return "loading"; 81 | case RANDOM: 82 | return "random"; 83 | } 84 | return ""; 85 | } 86 | 87 | } // namespace CO 88 | 89 | } // namespace dfly 90 | -------------------------------------------------------------------------------- /server/command_registry.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | #include "base/function2.hpp" 13 | #include "server/common_types.h" 14 | 15 | namespace dfly { 16 | 17 | class ConnectionContext; 18 | 19 | namespace CO { 20 | 21 | enum CommandOpt : uint32_t { 22 | READONLY = 1, 23 | FAST = 2, 24 | WRITE = 4, 25 | LOADING = 8, 26 | DENYOOM = 0x10, // use-memory in redis. 27 | STALE = 0x20, 28 | RANDOM = 0x40, 29 | }; 30 | 31 | const char* OptName(CommandOpt fl); 32 | 33 | }; // namespace CO 34 | 35 | class CommandId { 36 | public: 37 | using Handler = std::function; 38 | 39 | /** 40 | * @brief Construct a new Command Id object 41 | * 42 | * @param name 43 | * @param mask 44 | * @param arity - positive if command has fixed number of required arguments 45 | * negative if command has minimum number of required arguments, but may have 46 | * more. 47 | * @param first_key - position of first key in argument list 48 | * @param last_key - position of last key in argument list, 49 | * -1 means the last key index is (arg_length - 1), -2 means that the last key 50 | * index is (arg_length - 2). 
51 | * @param step - step count for locating repeating keys 52 | */ 53 | CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first_key, int8_t last_key, 54 | int8_t step); 55 | 56 | const char* name() const { 57 | return name_; 58 | } 59 | 60 | int arity() const { 61 | return arity_; 62 | } 63 | 64 | uint32_t opt_mask() const { 65 | return opt_mask_; 66 | } 67 | 68 | int8_t first_key_pos() const { 69 | return first_key_; 70 | } 71 | 72 | int8_t last_key_pos() const { 73 | return last_key_; 74 | } 75 | 76 | bool is_multi_key() const { 77 | return last_key_ != first_key_; 78 | } 79 | 80 | int8_t key_arg_step() const { 81 | return step_key_; 82 | } 83 | 84 | CommandId& SetHandler(Handler f) { 85 | handler_ = std::move(f); 86 | return *this; 87 | } 88 | 89 | void Invoke(CmdArgList args, ConnectionContext* cntx) const { 90 | handler_(std::move(args), cntx); 91 | } 92 | 93 | static const char* OptName(CO::CommandOpt fl); 94 | static uint32_t OptCount(uint32_t mask); 95 | 96 | private: 97 | const char* name_; 98 | 99 | uint32_t opt_mask_; 100 | int8_t arity_; 101 | int8_t first_key_; 102 | int8_t last_key_; 103 | int8_t step_key_; 104 | 105 | Handler handler_; 106 | }; 107 | 108 | class CommandRegistry { 109 | absl::flat_hash_map cmd_map_; 110 | 111 | public: 112 | CommandRegistry(); 113 | 114 | CommandRegistry& operator<<(CommandId cmd) { 115 | const char* k = cmd.name(); 116 | cmd_map_.emplace(k, std::move(cmd)); 117 | 118 | return *this; 119 | } 120 | 121 | const CommandId* Find(std::string_view cmd) const { 122 | auto it = cmd_map_.find(cmd); 123 | return it == cmd_map_.end() ? nullptr : &it->second; 124 | } 125 | 126 | CommandId* Find(std::string_view cmd) { 127 | auto it = cmd_map_.find(cmd); 128 | return it == cmd_map_.end() ? 
// Returns the error text reported when command `cmd` is called with a wrong number of
// arguments. The command name is embedded verbatim between single quotes.
std::string WrongNumArgsError(std::string_view cmd) {
  std::string msg("wrong number of arguments for '");
  msg.append(cmd.data(), cmd.size());
  msg.append("' command");
  return msg;
}
// State flags of a single connection, kept as a bitmask of `Mask` values.
struct ConnectionState {
  enum Mask : uint32_t {
    ASYNC_DISPATCH = 1,  // whether a command is handled via async dispatch.
    CONN_CLOSING = 2,    // could be because of unrecoverable error or planned action.
  };

  uint32_t mask = 0;  // A bitmask of Mask values.

  // True when the CONN_CLOSING bit is set in `mask`.
  bool IsClosing() const {
    return (mask & CONN_CLOSING) != 0;
  }

  // True when the ASYNC_DISPATCH bit is set in `mask`.
  bool IsRunViaDispatch() const {
    return (mask & ASYNC_DISPATCH) != 0;
  }
};
// DashStr - replaces all underscores to dash characters and keeps the rest as is.
//
// Holds a NUL-terminated copy of at most N - 1 characters of the source string in a
// fixed-size buffer of N bytes. The config macros pass sizeof(string literal) as N, for
// which the stored text is the complete literal.
template <unsigned N> class DashStr {
 public:
  // Copies `s` with every '_' turned into '-'.
  //
  // The copy stops at the terminator of `s` or after N - 1 characters, whichever comes
  // first, so a source shorter than N is never read past its end and a longer one is
  // truncated rather than left without a terminator. The unused tail is zero-filled.
  // Intentionally implicit, as before.
  DashStr(const char* s) {
    unsigned i = 0;
    for (; i + 1 < N && s[i] != '\0'; ++i) {
      str_[i] = (s[i] == '_') ? '-' : s[i];
    }
    for (; i < N; ++i) {
      str_[i] = '\0';
    }
  }

  // Returns the converted, NUL-terminated string. Valid for the lifetime of this object.
  const char* str() const {
    return str_;
  }

 private:
  char str_[N];
};
validator); \ 53 | } \ 54 | using fL##shorttype::FLAGS_##name 55 | 56 | #define BIND_CONFIG(var) [](const char* nm, auto val) { \ 57 | var = val; \ 58 | return true;} 59 | 60 | 61 | #define BIND_ENUM_CONFIG(enum_arr, dest_var) [](const char* nm, const std::string& val) { \ 62 | return ::dfly::ValidateConfigEnum(nm, val, enum_arr, ABSL_ARRAYSIZE(enum_arr), \ 63 | &(dest_var));} 64 | 65 | #define CONFIG_uint64(name,val, txt, validator) \ 66 | DEFINE_CONFIG_VAR(GFLAGS_NAMESPACE::uint64, U64, name, val, txt, validator) 67 | 68 | 69 | #define CONFIG_string(name, val, txt, validator) \ 70 | namespace fLS { \ 71 | using ::fLS::clstring; \ 72 | using ::fLS::StringFlagDestructor; \ 73 | static union { void* align; char s[sizeof(clstring)]; } s_##name[2]; \ 74 | clstring* const FLAGS_no##name = ::fLS:: \ 75 | dont_pass0toDEFINE_string(s_##name[0].s, \ 76 | val); \ 77 | static ::dfly::DashStr _dash_##name(#name); \ 78 | static GFLAGS_NAMESPACE::FlagRegisterer o_##name( \ 79 | _dash_##name.str(), MAYBE_STRIPPED_HELP(txt), __FILE__, \ 80 | FLAGS_no##name, new (s_##name[1].s) clstring(*FLAGS_no##name)); \ 81 | static StringFlagDestructor d_##name(s_##name[0].s, s_##name[1].s); \ 82 | extern GFLAGS_DLL_DEFINE_FLAG clstring& FLAGS_##name; \ 83 | using fLS::FLAGS_##name; \ 84 | clstring& FLAGS_##name = *FLAGS_no##name; \ 85 | static const bool name##_val_reg = \ 86 | GFLAGS_NAMESPACE::RegisterFlagValidator(&FLAGS_##name, validator); \ 87 | } \ 88 | using fLS::FLAGS_##name 89 | 90 | #define CONFIG_enum(name, val, txt, enum_arr, dest_var) \ 91 | CONFIG_string(name, val, txt, BIND_ENUM_CONFIG(enum_arr, dest_var)) 92 | 93 | -------------------------------------------------------------------------------- /server/conn_context.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
// Context object passed to command handlers (CommandId handlers receive a
// ConnectionContext*). It derives from ReplyBuilder, so replies are sent through the
// context itself.
class ConnectionContext : public ReplyBuilder {
 public:
  // Initializes the ReplyBuilder base with owner->protocol() and `stream`, and remembers
  // `owner`. `owner` is dereferenced here, so it must not be null.
  ConnectionContext(::io::Sink* stream, Connection* owner);

  // TODO: to introduce proper accessors.
  // NOTE(review): presumably the command currently being handled — confirm against callers.
  const CommandId* cid = nullptr;
  EngineShardSet* shard_set = nullptr;

  // Returns the connection this context was created for.
  Connection* owner() {
    return owner_;
  }

  // Returns the protocol of the owning connection (forwards to owner_->protocol()).
  Protocol protocol() const;

  // Bitmask-based state flags of this connection (see ConnectionState).
  ConnectionState conn_state;

 private:
  Connection* owner_;
};
3 | // 4 | 5 | #include "server/db_slice.h" 6 | 7 | #include 8 | #include 9 | 10 | #include "base/logging.h" 11 | #include "server/engine_shard_set.h" 12 | #include "util/fiber_sched_algo.h" 13 | #include "util/proactor_base.h" 14 | 15 | namespace dfly { 16 | 17 | using namespace boost; 18 | using namespace std; 19 | using namespace util; 20 | 21 | DbSlice::DbSlice(uint32_t index, EngineShard* owner) : shard_id_(index), owner_(owner) { 22 | db_arr_.emplace_back(); 23 | CreateDbRedis(0); 24 | } 25 | 26 | DbSlice::~DbSlice() { 27 | for (auto& db : db_arr_) { 28 | if (!db.main_table) 29 | continue; 30 | db.main_table.reset(); 31 | } 32 | } 33 | 34 | void DbSlice::Reserve(DbIndex db_ind, size_t key_size) { 35 | ActivateDb(db_ind); 36 | 37 | auto& db = db_arr_[db_ind]; 38 | DCHECK(db.main_table); 39 | 40 | db.main_table->reserve(key_size); 41 | } 42 | 43 | auto DbSlice::Find(DbIndex db_index, std::string_view key) const -> OpResult { 44 | DCHECK_LT(db_index, db_arr_.size()); 45 | DCHECK(db_arr_[db_index].main_table); 46 | 47 | auto& db = db_arr_[db_index]; 48 | MainIterator it = db.main_table->find(key); 49 | 50 | if (it == db.main_table->end()) { 51 | return OpStatus::KEY_NOTFOUND; 52 | } 53 | 54 | return it; 55 | } 56 | 57 | auto DbSlice::AddOrFind(DbIndex db_index, std::string_view key) -> pair { 58 | DCHECK_LT(db_index, db_arr_.size()); 59 | DCHECK(db_arr_[db_index].main_table); 60 | 61 | auto& db = db_arr_[db_index]; 62 | 63 | pair res = db.main_table->emplace(key, MainValue{}); 64 | if (res.second) { // new entry 65 | db.stats.obj_memory_usage += res.first->first.capacity(); 66 | 67 | return make_pair(res.first, true); 68 | } 69 | 70 | return res; 71 | } 72 | 73 | void DbSlice::ActivateDb(DbIndex db_ind) { 74 | if (db_arr_.size() <= db_ind) 75 | db_arr_.resize(db_ind + 1); 76 | CreateDbRedis(db_ind); 77 | } 78 | 79 | void DbSlice::CreateDbRedis(unsigned index) { 80 | auto& db = db_arr_[index]; 81 | if (!db.main_table) { 82 | db.main_table.reset(new MainTable); 83 
| } 84 | } 85 | 86 | void DbSlice::AddNew(DbIndex db_ind, std::string_view key, MainValue obj, uint64_t expire_at_ms) { 87 | CHECK(AddIfNotExist(db_ind, key, std::move(obj), expire_at_ms)); 88 | } 89 | 90 | bool DbSlice::AddIfNotExist(DbIndex db_ind, std::string_view key, MainValue obj, 91 | uint64_t expire_at_ms) { 92 | auto& db = db_arr_[db_ind]; 93 | 94 | auto [new_entry, success] = db.main_table->emplace(key, obj); 95 | if (!success) 96 | return false; // in this case obj won't be moved and will be destroyed during unwinding. 97 | 98 | db.stats.obj_memory_usage += (new_entry->first.capacity() + new_entry->second.capacity()); 99 | 100 | if (expire_at_ms) { 101 | // TODO 102 | } 103 | 104 | return true; 105 | } 106 | 107 | size_t DbSlice::DbSize(DbIndex db_ind) const { 108 | DCHECK_LT(db_ind, db_array_size()); 109 | 110 | if (IsDbValid(db_ind)) { 111 | return db_arr_[db_ind].main_table->size(); 112 | } 113 | return 0; 114 | } 115 | 116 | } // namespace dfly 117 | -------------------------------------------------------------------------------- /server/db_slice.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | 5 | #pragma once 6 | 7 | #include "server/common_types.h" 8 | #include "server/op_status.h" 9 | 10 | namespace util { 11 | class ProactorBase; 12 | } 13 | 14 | namespace dfly { 15 | 16 | class DbSlice { 17 | struct InternalDbStats { 18 | // Object memory usage besides hash-table capacity. 19 | size_t obj_memory_usage = 0; 20 | }; 21 | 22 | public: 23 | DbSlice(uint32_t index, EngineShard* owner); 24 | ~DbSlice(); 25 | 26 | 27 | // Activates `db_ind` database if it does not exist (see ActivateDb below). 28 | void Reserve(DbIndex db_ind, size_t key_size); 29 | 30 | OpResult Find(DbIndex db_index, std::string_view key) const; 31 | 32 | // Return .second=true if insertion ocurred, false if we return the existing key. 
33 | std::pair AddOrFind(DbIndex db_ind, std::string_view key); 34 | 35 | // Adds a new entry. Requires: key does not exist in this slice. 36 | void AddNew(DbIndex db_ind, std::string_view key, MainValue obj, uint64_t expire_at_ms); 37 | 38 | // Adds a new entry if a key does not exists. Returns true if insertion took place, 39 | // false otherwise. expire_at_ms equal to 0 - means no expiry. 40 | bool AddIfNotExist(DbIndex db_ind, std::string_view key, MainValue obj, uint64_t expire_at_ms); 41 | 42 | // Creates a database with index `db_ind`. If such database exists does nothing. 43 | void ActivateDb(DbIndex db_ind); 44 | 45 | size_t db_array_size() const { 46 | return db_arr_.size(); 47 | } 48 | 49 | bool IsDbValid(DbIndex id) const { 50 | return bool(db_arr_[id].main_table); 51 | } 52 | 53 | // Returns existing keys count in the db. 54 | size_t DbSize(DbIndex db_ind) const; 55 | 56 | ShardId shard_id() const { return shard_id_;} 57 | 58 | private: 59 | 60 | void CreateDbRedis(unsigned index); 61 | 62 | ShardId shard_id_; 63 | 64 | EngineShard* owner_; 65 | 66 | struct DbRedis { 67 | std::unique_ptr main_table; 68 | mutable InternalDbStats stats; 69 | }; 70 | 71 | std::vector db_arr_; 72 | }; 73 | 74 | } // namespace dfly 75 | -------------------------------------------------------------------------------- /server/debugcmd.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | #include "server/debugcmd.h" 5 | 6 | #include 7 | 8 | #include "base/logging.h" 9 | #include "server/engine_shard_set.h" 10 | 11 | namespace dfly { 12 | 13 | using namespace boost; 14 | using namespace std; 15 | 16 | static const char kUintErr[] = "value is out of range, must be positive"; 17 | 18 | struct PopulateBatch { 19 | uint64_t index[32]; 20 | uint64_t sz = 0; 21 | }; 22 | 23 | void DoPopulateBatch(std::string_view prefix, size_t val_size, const PopulateBatch& ps) { 24 | EngineShard* es = EngineShard::tlocal(); 25 | DbSlice& db_slice = es->db_slice; 26 | 27 | for (unsigned i = 0; i < ps.sz; ++i) { 28 | string key = absl::StrCat(prefix, ":", ps.index[i]); 29 | string val = absl::StrCat("value:", ps.index[i]); 30 | 31 | if (val.size() < val_size) { 32 | val.resize(val_size, 'x'); 33 | } 34 | auto [it, res] = db_slice.AddOrFind(0, key); 35 | if (res) { 36 | it->second = std::move(val); 37 | } 38 | } 39 | } 40 | 41 | DebugCmd::DebugCmd(EngineShardSet* ess, ConnectionContext* cntx) : ess_(ess), cntx_(cntx) { 42 | } 43 | 44 | void DebugCmd::Run(CmdArgList args) { 45 | std::string_view subcmd = ArgS(args, 1); 46 | if (subcmd == "HELP") { 47 | std::string_view help_arr[] = { 48 | "DEBUG [ [value] [opt] ...]. Subcommands are:", 49 | "POPULATE [] []", 50 | " Create string keys named key:. If is specified then", 51 | " it is used instead of the 'key' prefix.", 52 | "HELP", 53 | " Prints this help.", 54 | }; 55 | return cntx_->SendSimpleStrArr(help_arr, ABSL_ARRAYSIZE(help_arr)); 56 | } 57 | 58 | VLOG(1) << "subcmd " << subcmd; 59 | 60 | if (subcmd == "POPULATE") { 61 | return Populate(args); 62 | } 63 | 64 | string reply = absl::StrCat("Unknown subcommand or wrong number of arguments for '", subcmd, 65 | "'. 
Try DEBUG HELP."); 66 | return cntx_->SendError(reply); 67 | } 68 | 69 | void DebugCmd::Populate(CmdArgList args) { 70 | if (args.size() < 3 || args.size() > 5) { 71 | return cntx_->SendError( 72 | "Unknown subcommand or wrong number of arguments for 'populate'. Try DEBUG HELP."); 73 | } 74 | 75 | uint64_t total_count = 0; 76 | if (!absl::SimpleAtoi(ArgS(args, 2), &total_count)) 77 | return cntx_->SendError(kUintErr); 78 | std::string_view prefix{"key"}; 79 | 80 | if (args.size() > 3) { 81 | prefix = ArgS(args, 3); 82 | } 83 | uint32_t val_size = 0; 84 | if (args.size() > 4) { 85 | std::string_view str = ArgS(args, 4); 86 | if (!absl::SimpleAtoi(str, &val_size)) 87 | return cntx_->SendError(kUintErr); 88 | } 89 | 90 | size_t runners_count = ess_->pool()->size(); 91 | vector> ranges(runners_count - 1); 92 | uint64_t batch_size = total_count / runners_count; 93 | size_t from = 0; 94 | for (size_t i = 0; i < ranges.size(); ++i) { 95 | ranges[i].first = from; 96 | ranges[i].second = batch_size; 97 | from += batch_size; 98 | } 99 | ranges.emplace_back(from, total_count - from); 100 | 101 | auto distribute_cb = [this, val_size, prefix]( 102 | uint64_t from, uint64_t len) { 103 | string key = absl::StrCat(prefix, ":"); 104 | size_t prefsize = key.size(); 105 | std::vector ps(ess_->size(), PopulateBatch{}); 106 | 107 | for (uint64_t i = from; i < from + len; ++i) { 108 | absl::StrAppend(&key, i); 109 | ShardId sid = Shard(key, ess_->size()); 110 | key.resize(prefsize); 111 | 112 | auto& pops = ps[sid]; 113 | pops.index[pops.sz++] = i; 114 | if (pops.sz == 32) { 115 | ess_->Add(sid, [=, p = pops] { 116 | DoPopulateBatch(prefix, val_size, p); 117 | if (i % 100 == 0) { 118 | this_fiber::yield(); 119 | } 120 | }); 121 | 122 | // we capture pops by value so we can override it here. 
123 | pops.sz = 0; 124 | } 125 | } 126 | 127 | ess_->RunBriefInParallel( 128 | [&](EngineShard* shard) { DoPopulateBatch(prefix, val_size, ps[shard->shard_id()]); }); 129 | }; 130 | vector fb_arr(ranges.size()); 131 | for (size_t i = 0; i < ranges.size(); ++i) { 132 | fb_arr[i] = ess_->pool()->at(i)->LaunchFiber(distribute_cb, ranges[i].first, ranges[i].second); 133 | } 134 | for (auto& fb : fb_arr) 135 | fb.join(); 136 | 137 | cntx_->SendOk(); 138 | } 139 | 140 | } // namespace dfly 141 | -------------------------------------------------------------------------------- /server/debugcmd.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | 5 | #pragma once 6 | 7 | #include "server/conn_context.h" 8 | 9 | namespace dfly { 10 | 11 | class EngineShardSet; 12 | 13 | class DebugCmd { 14 | public: 15 | DebugCmd(EngineShardSet* ess, ConnectionContext* cntx); 16 | 17 | void Run(CmdArgList args); 18 | 19 | private: 20 | void Populate(CmdArgList args); 21 | 22 | EngineShardSet* ess_; 23 | ConnectionContext* cntx_; 24 | }; 25 | 26 | } // namespace dfly 27 | -------------------------------------------------------------------------------- /server/dfly_main.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | 5 | #include "base/init.h" 6 | #include "server/main_service.h" 7 | #include "server/dragonfly_listener.h" 8 | #include "util/accept_server.h" 9 | #include "util/uring/uring_pool.h" 10 | #include "util/varz.h" 11 | 12 | ABSL_FLAG(int32_t, http_port, 8080, "Http port."); 13 | ABSL_DECLARE_FLAG(uint32_t, port); 14 | ABSL_DECLARE_FLAG(uint32_t, memcache_port); 15 | 16 | using namespace util; 17 | using absl::GetFlag; 18 | 19 | namespace dfly { 20 | 21 | void RunEngine(ProactorPool* pool, AcceptServer* acceptor, HttpListener<>* http) { 22 | Service service(pool); 23 | service.Init(acceptor); 24 | 25 | if (http) { 26 | service.RegisterHttp(http); 27 | } 28 | 29 | acceptor->AddListener(GetFlag(FLAGS_port), new Listener{Protocol::REDIS, &service}); 30 | auto mc_port = GetFlag(FLAGS_memcache_port); 31 | if (mc_port > 0) { 32 | acceptor->AddListener(mc_port, new Listener{Protocol::MEMCACHE, &service}); 33 | } 34 | 35 | acceptor->Run(); 36 | acceptor->Wait(); 37 | 38 | service.Shutdown(); 39 | } 40 | 41 | } // namespace dfly 42 | 43 | 44 | int main(int argc, char* argv[]) { 45 | MainInitGuard guard(&argc, &argv); 46 | 47 | CHECK_GT(GetFlag(FLAGS_port), 0u); 48 | 49 | uring::UringPool pp{1024}; 50 | pp.Run(); 51 | 52 | AcceptServer acceptor(&pp); 53 | HttpListener<>* http_listener = nullptr; 54 | 55 | if (GetFlag(FLAGS_http_port) >= 0) { 56 | http_listener = new HttpListener<>; 57 | http_listener->enable_metrics(); 58 | 59 | // Ownership over http_listener is moved to the acceptor. 60 | uint16_t port = acceptor.AddListener(GetFlag(FLAGS_http_port), http_listener); 61 | 62 | LOG(INFO) << "Started http service on port " << port; 63 | } 64 | 65 | dfly::RunEngine(&pp, &acceptor, http_listener); 66 | 67 | pp.Stop(); 68 | 69 | return 0; 70 | } 71 | -------------------------------------------------------------------------------- /server/dfly_protocol.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. 
All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | namespace dfly { 10 | 11 | enum class Protocol : uint8_t { 12 | MEMCACHE = 1, 13 | REDIS = 2 14 | }; 15 | 16 | } // namespace dfly 17 | -------------------------------------------------------------------------------- /server/dragonfly_connection.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | 5 | #include "server/dragonfly_connection.h" 6 | 7 | #include 8 | 9 | #include 10 | 11 | #include "base/io_buf.h" 12 | #include "base/logging.h" 13 | #include "server/command_registry.h" 14 | #include "server/conn_context.h" 15 | #include "server/main_service.h" 16 | #include "server/memcache_parser.h" 17 | #include "server/redis_parser.h" 18 | #include "util/fiber_sched_algo.h" 19 | #include "util/tls/tls_socket.h" 20 | 21 | using namespace util; 22 | using namespace std; 23 | namespace this_fiber = boost::this_fiber; 24 | namespace fibers = boost::fibers; 25 | 26 | namespace dfly { 27 | namespace { 28 | 29 | void SendProtocolError(RedisParser::Result pres, FiberSocketBase* peer) { 30 | string res("-ERR Protocol error: "); 31 | if (pres == RedisParser::BAD_BULKLEN) { 32 | res.append("invalid bulk length\r\n"); 33 | } else { 34 | CHECK_EQ(RedisParser::BAD_ARRAYLEN, pres); 35 | res.append("invalid multibulk length\r\n"); 36 | } 37 | 38 | auto size_res = peer->Send(::io::Buffer(res)); 39 | if (!size_res) { 40 | LOG(WARNING) << "Error " << size_res.error(); 41 | } 42 | } 43 | 44 | void RespToArgList(const RespVec& src, CmdArgVec* dest) { 45 | dest->resize(src.size()); 46 | for (size_t i = 0; i < src.size(); ++i) { 47 | (*dest)[i] = ToMSS(src[i].GetBuf()); 48 | } 49 | } 50 | 51 | constexpr size_t kMinReadSize = 256; 52 | 53 | } // namespace 54 | 55 | struct Connection::Shutdown { 56 | absl::flat_hash_map map; 57 | 
ShutdownHandle next_handle = 1; 58 | 59 | ShutdownHandle Add(ShutdownCb cb) { 60 | map[next_handle] = move(cb); 61 | return next_handle++; 62 | } 63 | 64 | void Remove(ShutdownHandle sh) { 65 | map.erase(sh); 66 | } 67 | }; 68 | 69 | Connection::Connection(Protocol protocol, Service* service, SSL_CTX* ctx) 70 | : service_(service), ctx_(ctx) { 71 | protocol_ = protocol; 72 | 73 | switch (protocol) { 74 | case Protocol::REDIS: 75 | redis_parser_.reset(new RedisParser); 76 | break; 77 | case Protocol::MEMCACHE: 78 | memcache_parser_.reset(new MemcacheParser); 79 | break; 80 | } 81 | } 82 | 83 | Connection::~Connection() { 84 | } 85 | 86 | void Connection::OnShutdown() { 87 | VLOG(1) << "Connection::OnShutdown"; 88 | if (shutdown_) { 89 | for (const auto& k_v : shutdown_->map) { 90 | k_v.second(); 91 | } 92 | } 93 | } 94 | 95 | auto Connection::RegisterShutdownHook(ShutdownCb cb) -> ShutdownHandle { 96 | if (!shutdown_) { 97 | shutdown_ = make_unique(); 98 | } 99 | return shutdown_->Add(std::move(cb)); 100 | } 101 | 102 | void Connection::UnregisterShutdownHook(ShutdownHandle id) { 103 | if (shutdown_) { 104 | shutdown_->Remove(id); 105 | if (shutdown_->map.empty()) 106 | shutdown_.reset(); 107 | } 108 | } 109 | 110 | void Connection::HandleRequests() { 111 | this_fiber::properties().set_name("DflyConnection"); 112 | 113 | LinuxSocketBase* lsb = static_cast(socket_.get()); 114 | // int val = 1; 115 | // CHECK_EQ(0, setsockopt(socket_->native_handle(), SOL_TCP, TCP_NODELAY, &val, sizeof(val))); 116 | 117 | auto ep = lsb->RemoteEndpoint(); 118 | 119 | std::unique_ptr tls_sock; 120 | if (ctx_) { 121 | tls_sock.reset(new tls::TlsSocket(socket_.get())); 122 | tls_sock->InitSSL(ctx_); 123 | 124 | FiberSocketBase::AcceptResult aresult = tls_sock->Accept(); 125 | if (!aresult) { 126 | LOG(WARNING) << "Error handshaking " << aresult.error().message(); 127 | return; 128 | } 129 | VLOG(1) << "TLS handshake succeeded"; 130 | } 131 | FiberSocketBase* peer = tls_sock ? 
(FiberSocketBase*)tls_sock.get() : socket_.get(); 132 | cc_.reset(new ConnectionContext(peer, this)); 133 | cc_->shard_set = &service_->shard_set(); 134 | 135 | InputLoop(peer); 136 | 137 | VLOG(1) << "Closed connection for peer " << ep; 138 | } 139 | 140 | void Connection::InputLoop(FiberSocketBase* peer) { 141 | base::IoBuf io_buf{kMinReadSize}; 142 | 143 | auto dispatch_fb = fibers::fiber(fibers::launch::dispatch, [&] { DispatchFiber(peer); }); 144 | ParserStatus status = OK; 145 | std::error_code ec; 146 | 147 | do { 148 | auto buf = io_buf.AppendBuffer(); 149 | ::io::Result recv_sz = peer->Recv(buf); 150 | 151 | if (!recv_sz) { 152 | ec = recv_sz.error(); 153 | status = OK; 154 | break; 155 | } 156 | 157 | io_buf.CommitWrite(*recv_sz); 158 | 159 | if (redis_parser_) 160 | status = ParseRedis(&io_buf); 161 | else { 162 | DCHECK(memcache_parser_); 163 | status = ParseMemcache(&io_buf); 164 | } 165 | 166 | if (status == NEED_MORE) { 167 | status = OK; 168 | } else if (status != OK) { 169 | break; 170 | } 171 | } while (peer->IsOpen() && !cc_->ec()); 172 | 173 | cc_->conn_state.mask |= ConnectionState::CONN_CLOSING; // Signal dispatch to close. 
174 | evc_.notify(); 175 | dispatch_fb.join(); 176 | 177 | if (cc_->ec()) { 178 | ec = cc_->ec(); 179 | } else { 180 | if (status == ERROR) { 181 | VLOG(1) << "Error stats " << status; 182 | if (redis_parser_) { 183 | SendProtocolError(RedisParser::Result(parser_error_), peer); 184 | } else { 185 | string_view sv{"CLIENT_ERROR bad command line format\r\n"}; 186 | auto size_res = peer->Send(::io::Buffer(sv)); 187 | if (!size_res) { 188 | LOG(WARNING) << "Error " << size_res.error(); 189 | ec = size_res.error(); 190 | } 191 | } 192 | } 193 | } 194 | 195 | if (ec && !FiberSocketBase::IsConnClosed(ec)) { 196 | LOG(WARNING) << "Socket error " << ec; 197 | } 198 | } 199 | 200 | auto Connection::ParseRedis(base::IoBuf* io_buf) -> ParserStatus { 201 | RespVec args; 202 | CmdArgVec arg_vec; 203 | uint32_t consumed = 0; 204 | 205 | RedisParser::Result result = RedisParser::OK; 206 | 207 | do { 208 | result = redis_parser_->Parse(io_buf->InputBuffer(), &consumed, &args); 209 | 210 | if (result == RedisParser::OK && !args.empty()) { 211 | RespExpr& first = args.front(); 212 | if (first.type == RespExpr::STRING) { 213 | DVLOG(2) << "Got Args with first token " << ToSV(first.GetBuf()); 214 | } 215 | 216 | // An optimization to skip dispatch_q_ if no pipelining is identified. 217 | // We use ASYNC_DISPATCH as a lock to avoid out-of-order replies when the 218 | // dispatch fiber pulls the last record but is still processing the command and then this 219 | // fiber enters the condition below and executes out of order. 
220 | bool is_sync_dispatch = !cc_->conn_state.IsRunViaDispatch(); 221 | if (dispatch_q_.empty() && is_sync_dispatch && consumed >= io_buf->InputLen()) { 222 | RespToArgList(args, &arg_vec); 223 | service_->DispatchCommand(CmdArgList{arg_vec.data(), arg_vec.size()}, cc_.get()); 224 | } else { 225 | // Dispatch via queue to speedup input reading, 226 | Request* req = FromArgs(std::move(args)); 227 | dispatch_q_.emplace_back(req); 228 | if (dispatch_q_.size() == 1) { 229 | evc_.notify(); 230 | } else if (dispatch_q_.size() > 10) { 231 | this_fiber::yield(); 232 | } 233 | } 234 | } 235 | io_buf->ConsumeInput(consumed); 236 | } while (RedisParser::OK == result && !cc_->ec()); 237 | 238 | parser_error_ = result; 239 | if (result == RedisParser::OK) 240 | return OK; 241 | 242 | if (result == RedisParser::INPUT_PENDING) 243 | return NEED_MORE; 244 | 245 | return ERROR; 246 | } 247 | 248 | auto Connection::ParseMemcache(base::IoBuf* io_buf) -> ParserStatus { 249 | MemcacheParser::Result result = MemcacheParser::OK; 250 | uint32_t consumed = 0; 251 | MemcacheParser::Command cmd; 252 | string_view value; 253 | do { 254 | string_view str = ToSV(io_buf->InputBuffer()); 255 | result = memcache_parser_->Parse(str, &consumed, &cmd); 256 | 257 | if (result != MemcacheParser::OK) { 258 | io_buf->ConsumeInput(consumed); 259 | break; 260 | } 261 | 262 | size_t total_len = consumed; 263 | if (MemcacheParser::IsStoreCmd(cmd.type)) { 264 | total_len += cmd.bytes_len + 2; 265 | if (io_buf->InputLen() >= total_len) { 266 | value = str.substr(consumed, cmd.bytes_len); 267 | // TODO: dispatch. 268 | } else { 269 | return NEED_MORE; 270 | } 271 | } 272 | 273 | // An optimization to skip dispatch_q_ if no pipelining is identified. 274 | // We use ASYNC_DISPATCH as a lock to avoid out-of-order replies when the 275 | // dispatch fiber pulls the last record but is still processing the command and then this 276 | // fiber enters the condition below and executes out of order. 
277 | bool is_sync_dispatch = (cc_->conn_state.mask & ConnectionState::ASYNC_DISPATCH) == 0; 278 | if (dispatch_q_.empty() && is_sync_dispatch && consumed >= io_buf->InputLen()) { 279 | service_->DispatchMC(cmd, value, cc_.get()); 280 | } 281 | io_buf->ConsumeInput(consumed); 282 | } while (!cc_->ec()); 283 | 284 | parser_error_ = result; 285 | 286 | if (result == MemcacheParser::OK) 287 | return OK; 288 | 289 | if (result == MemcacheParser::INPUT_PENDING) 290 | return NEED_MORE; 291 | 292 | return ERROR; 293 | } 294 | 295 | // DispatchFiber handles commands coming from the InputLoop. 296 | // Thus, InputLoop can quickly read data from the input buffer, parse it and push 297 | // into the dispatch queue and DispatchFiber will run those commands asynchronously with InputLoop. 298 | // Note: in some cases, InputLoop may decide to dispatch directly and bypass the DispatchFiber. 299 | void Connection::DispatchFiber(util::FiberSocketBase* peer) { 300 | this_fiber::properties().set_name("DispatchFiber"); 301 | 302 | while (!cc_->ec()) { 303 | evc_.await([this] { return cc_->conn_state.IsClosing() || !dispatch_q_.empty(); }); 304 | if (cc_->conn_state.IsClosing()) 305 | break; 306 | 307 | std::unique_ptr req{dispatch_q_.front()}; 308 | dispatch_q_.pop_front(); 309 | 310 | cc_->SetBatchMode(!dispatch_q_.empty()); 311 | cc_->conn_state.mask |= ConnectionState::ASYNC_DISPATCH; 312 | service_->DispatchCommand(CmdArgList{req->args.data(), req->args.size()}, cc_.get()); 313 | cc_->conn_state.mask &= ~ConnectionState::ASYNC_DISPATCH; 314 | } 315 | 316 | cc_->conn_state.mask |= ConnectionState::CONN_CLOSING; 317 | } 318 | 319 | auto Connection::FromArgs(RespVec args) -> Request* { 320 | DCHECK(!args.empty()); 321 | size_t backed_sz = 0; 322 | for (const auto& arg : args) { 323 | CHECK_EQ(RespExpr::STRING, arg.type); 324 | backed_sz += arg.GetBuf().size(); 325 | } 326 | DCHECK(backed_sz); 327 | 328 | Request* req = new Request{args.size(), backed_sz}; 329 | 330 | auto* next = 
req->storage.data(); 331 | for (size_t i = 0; i < args.size(); ++i) { 332 | auto buf = args[i].GetBuf(); 333 | size_t s = buf.size(); 334 | memcpy(next, buf.data(), s); 335 | req->args[i] = MutableStrSpan(next, s); 336 | next += s; 337 | } 338 | 339 | return req; 340 | } 341 | 342 | } // namespace dfly 343 | -------------------------------------------------------------------------------- /server/dragonfly_connection.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | #include 10 | 11 | #include "base/io_buf.h" 12 | #include "server/common_types.h" 13 | #include "server/dfly_protocol.h" 14 | #include "server/resp_expr.h" 15 | #include "util/connection.h" 16 | #include "util/fibers/event_count.h" 17 | 18 | typedef struct ssl_ctx_st SSL_CTX; 19 | 20 | namespace dfly { 21 | 22 | class ConnectionContext; 23 | class RedisParser; 24 | class Service; 25 | class MemcacheParser; 26 | 27 | class Connection : public util::Connection { 28 | public: 29 | Connection(Protocol protocol, Service* service, SSL_CTX* ctx); 30 | ~Connection(); 31 | 32 | using error_code = std::error_code; 33 | using ShutdownCb = std::function; 34 | using ShutdownHandle = unsigned; 35 | 36 | ShutdownHandle RegisterShutdownHook(ShutdownCb cb); 37 | void UnregisterShutdownHook(ShutdownHandle id); 38 | 39 | Protocol protocol() const { 40 | return protocol_; 41 | } 42 | 43 | protected: 44 | void OnShutdown() override; 45 | 46 | private: 47 | enum ParserStatus { OK, NEED_MORE, ERROR }; 48 | 49 | void HandleRequests() final; 50 | 51 | void InputLoop(util::FiberSocketBase* peer); 52 | void DispatchFiber(util::FiberSocketBase* peer); 53 | 54 | ParserStatus ParseRedis(base::IoBuf* buf); 55 | ParserStatus ParseMemcache(base::IoBuf* buf); 56 | 57 | std::unique_ptr redis_parser_; 58 | std::unique_ptr memcache_parser_; 59 | Service* 
service_; 60 | SSL_CTX* ctx_; 61 | std::unique_ptr cc_; 62 | 63 | struct Request { 64 | absl::FixedArray args; 65 | absl::FixedArray storage; 66 | 67 | Request(size_t nargs, size_t capacity) : args(nargs), storage(capacity) { 68 | } 69 | Request(const Request&) = delete; 70 | }; 71 | 72 | static Request* FromArgs(RespVec args); 73 | 74 | std::deque dispatch_q_; // coordinated via evc_. 75 | util::fibers_ext::EventCount evc_; 76 | unsigned parser_error_ = 0; 77 | Protocol protocol_; 78 | struct Shutdown; 79 | std::unique_ptr shutdown_; 80 | }; 81 | 82 | } // namespace dfly 83 | -------------------------------------------------------------------------------- /server/dragonfly_listener.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | 5 | #include "server/dragonfly_listener.h" 6 | 7 | #include 8 | 9 | #include "base/flags.h" 10 | #include "base/logging.h" 11 | #include "server/dragonfly_connection.h" 12 | #include "util/proactor_pool.h" 13 | 14 | using namespace std; 15 | 16 | ABSL_FLAG(uint32_t, conn_threads, 0, "Number of threads used for handing server connections"); 17 | ABSL_FLAG(bool, tls, false, ""); 18 | ABSL_FLAG(bool, conn_use_incoming_cpu, false, 19 | "If true uses incoming cpu of a socket in order to distribute" 20 | " incoming connections"); 21 | 22 | ABSL_FLAG(string, tls_client_cert_file, "", ""); 23 | ABSL_FLAG(string, tls_client_key_file, "", ""); 24 | 25 | using absl::GetFlag; 26 | 27 | enum TlsClientAuth { 28 | CL_AUTH_NO = 0, 29 | CL_AUTH_YES = 1, 30 | CL_AUTH_OPTIONAL = 2, 31 | }; 32 | 33 | 34 | namespace dfly { 35 | 36 | using namespace util; 37 | 38 | // To connect: openssl s_client -cipher "ADH:@SECLEVEL=0" -state -crlf -connect 127.0.0.1:6380 39 | static SSL_CTX* CreateSslCntx() { 40 | SSL_CTX* ctx = SSL_CTX_new(TLS_server_method()); 41 | 42 | if (GetFlag(FLAGS_tls_client_key_file).empty()) { 43 | // 
To connect - use openssl s_client -cipher with either: 44 | // "AECDH:@SECLEVEL=0" or "ADH:@SECLEVEL=0" setting. 45 | CHECK_EQ(1, SSL_CTX_set_cipher_list(ctx, "aNULL")); 46 | 47 | // To allow anonymous ciphers. 48 | SSL_CTX_set_security_level(ctx, 0); 49 | 50 | // you can still connect with redis-cli with : 51 | // redis-cli --tls --insecure --tls-ciphers "ADH:@SECLEVEL=0" 52 | LOG(WARNING) 53 | << "tls-client-key-file not set, no keys are loaded and anonymous ciphers are enabled. " 54 | << "Do not use in production!"; 55 | } else { // tls_client_key_file is set. 56 | auto key_file = GetFlag(FLAGS_tls_client_key_file); 57 | CHECK_EQ(1, 58 | SSL_CTX_use_PrivateKey_file(ctx, key_file.c_str(), SSL_FILETYPE_PEM)); 59 | 60 | auto cert_file = GetFlag(FLAGS_tls_client_cert_file); 61 | 62 | if (!cert_file.empty()) { 63 | // TO connect with redis-cli you need both tls-client-key-file and tls-client-cert-file 64 | // loaded. Use `redis-cli --tls -p 6380 --insecure PING` to test 65 | 66 | CHECK_EQ(1, SSL_CTX_use_certificate_chain_file(ctx, cert_file.c_str())); 67 | } 68 | CHECK_EQ(1, SSL_CTX_set_cipher_list(ctx, "DEFAULT")); 69 | } 70 | SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION); 71 | 72 | SSL_CTX_set_options(ctx, SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS); 73 | 74 | unsigned mask = SSL_VERIFY_NONE; 75 | 76 | // if (tls_auth_clients_opt) 77 | // mask |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT; 78 | SSL_CTX_set_verify(ctx, mask, NULL); 79 | 80 | CHECK_EQ(1, SSL_CTX_set_dh_auto(ctx, 1)); 81 | 82 | return ctx; 83 | } 84 | 85 | Listener::Listener(Protocol protocol, Service* e) : engine_(e), protocol_(protocol) { 86 | if (GetFlag(FLAGS_tls)) { 87 | OPENSSL_init_ssl(OPENSSL_INIT_SSL_DEFAULT, NULL); 88 | ctx_ = CreateSslCntx(); 89 | } 90 | } 91 | 92 | Listener::~Listener() { 93 | SSL_CTX_free(ctx_); 94 | } 95 | 96 | util::Connection* Listener::NewConnection(ProactorBase* proactor) { 97 | return new Connection{protocol_, engine_, ctx_}; 98 | } 99 | 100 | void Listener::PreShutdown() { 101 | 
} 102 | 103 | void Listener::PostShutdown() { 104 | } 105 | 106 | // We can limit number of threads handling dragonfly connections. 107 | ProactorBase* Listener::PickConnectionProactor(LinuxSocketBase* sock) { 108 | util::ProactorPool* pp = pool(); 109 | uint32_t total = GetFlag(FLAGS_conn_threads); 110 | uint32_t id = kuint32max; 111 | 112 | if (total == 0 || total > pp->size()) { 113 | total = pp->size(); 114 | } 115 | 116 | if (GetFlag(FLAGS_conn_use_incoming_cpu)) { 117 | int fd = sock->native_handle(); 118 | 119 | int cpu, napi_id; 120 | socklen_t len = sizeof(cpu); 121 | 122 | CHECK_EQ(0, getsockopt(fd, SOL_SOCKET, SO_INCOMING_CPU, &cpu, &len)); 123 | CHECK_EQ(0, getsockopt(fd, SOL_SOCKET, SO_INCOMING_NAPI_ID, &napi_id, &len)); 124 | VLOG(1) << "CPU/NAPI for connection " << fd << " is " << cpu << "/" << napi_id; 125 | 126 | vector ids = pool()->MapCpuToThreads(cpu); 127 | if (!ids.empty()) { 128 | id = ids.front(); 129 | } 130 | } 131 | 132 | if (id == kuint32max) { 133 | id = next_id_.fetch_add(1, std::memory_order_relaxed); 134 | } 135 | 136 | return pp->at(id % total); 137 | } 138 | 139 | } // namespace dfly 140 | -------------------------------------------------------------------------------- /server/dragonfly_listener.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | 5 | #pragma once 6 | 7 | #include "util/listener_interface.h" 8 | #include "server/dfly_protocol.h" 9 | 10 | typedef struct ssl_ctx_st SSL_CTX; 11 | 12 | namespace dfly { 13 | 14 | class Service; 15 | 16 | class Listener : public util::ListenerInterface { 17 | public: 18 | Listener(Protocol protocol, Service*); 19 | ~Listener(); 20 | 21 | private: 22 | util::Connection* NewConnection(util::ProactorBase* proactor) final; 23 | util::ProactorBase* PickConnectionProactor(util::LinuxSocketBase* sock) final; 24 | 25 | void PreShutdown(); 26 | 27 | void PostShutdown(); 28 | 29 | Service* engine_; 30 | 31 | std::atomic_uint32_t next_id_{0}; 32 | Protocol protocol_; 33 | SSL_CTX* ctx_ = nullptr; 34 | }; 35 | 36 | } // namespace dfly 37 | -------------------------------------------------------------------------------- /server/engine_shard_set.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | 5 | #include "server/engine_shard_set.h" 6 | 7 | #include "base/logging.h" 8 | #include "util/fiber_sched_algo.h" 9 | #include "util/varz.h" 10 | 11 | namespace dfly { 12 | 13 | using namespace std; 14 | using namespace boost; 15 | using util::FiberProps; 16 | 17 | thread_local EngineShard* EngineShard::shard_ = nullptr; 18 | constexpr size_t kQueueLen = 64; 19 | 20 | EngineShard::EngineShard(ShardId index) 21 | : db_slice(index, this), queue_(kQueueLen) { 22 | fiber_q_ = fibers::fiber([this, index] { 23 | this_fiber::properties().set_name(absl::StrCat("shard_queue", index)); 24 | queue_.Run(); 25 | }); 26 | } 27 | 28 | EngineShard::~EngineShard() { 29 | queue_.Shutdown(); 30 | fiber_q_.join(); 31 | } 32 | 33 | void EngineShard::InitThreadLocal(ShardId index) { 34 | CHECK(shard_ == nullptr) << index; 35 | shard_ = new EngineShard(index); 36 | } 37 | 38 | void EngineShard::DestroyThreadLocal() { 39 | if (!shard_) 40 | return; 41 | 42 | uint32_t index = shard_->db_slice.shard_id(); 43 | delete shard_; 44 | shard_ = nullptr; 45 | 46 | VLOG(1) << "Shard reset " << index; 47 | } 48 | 49 | void EngineShardSet::Init(uint32_t sz) { 50 | CHECK_EQ(0u, size()); 51 | 52 | shard_queue_.resize(sz); 53 | } 54 | 55 | void EngineShardSet::InitThreadLocal(ShardId index) { 56 | EngineShard::InitThreadLocal(index); 57 | shard_queue_[index] = EngineShard::tlocal()->GetQueue(); 58 | } 59 | 60 | } // namespace dfly 61 | -------------------------------------------------------------------------------- /server/engine_shard_set.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | 5 | #pragma once 6 | 7 | #include "server/db_slice.h" 8 | #include "util/fibers/fibers_ext.h" 9 | #include "util/fibers/fiberqueue_threadpool.h" 10 | #include "util/proactor_pool.h" 11 | 12 | namespace dfly { 13 | 14 | using ShardId = uint16_t; 15 | 16 | class EngineShard { 17 | public: 18 | DbSlice db_slice; 19 | 20 | //EngineShard() is private down below. 21 | ~EngineShard(); 22 | 23 | static void InitThreadLocal(ShardId index); 24 | static void DestroyThreadLocal(); 25 | 26 | static EngineShard* tlocal() { 27 | return shard_; 28 | } 29 | 30 | ShardId shard_id() const { 31 | return db_slice.shard_id(); 32 | } 33 | 34 | ::util::fibers_ext::FiberQueue* GetQueue() { 35 | return &queue_; 36 | } 37 | 38 | private: 39 | EngineShard(ShardId index); 40 | 41 | ::util::fibers_ext::FiberQueue queue_; 42 | ::boost::fibers::fiber fiber_q_; 43 | 44 | static thread_local EngineShard* shard_; 45 | }; 46 | 47 | class EngineShardSet { 48 | public: 49 | explicit EngineShardSet(util::ProactorPool* pp) : pp_(pp) { 50 | } 51 | 52 | uint32_t size() const { 53 | return uint32_t(shard_queue_.size()); 54 | } 55 | 56 | util::ProactorPool* pool() { 57 | return pp_; 58 | } 59 | 60 | void Init(uint32_t size); 61 | void InitThreadLocal(ShardId index); 62 | 63 | template auto Await(ShardId sid, F&& f) { 64 | return shard_queue_[sid]->Await(std::forward(f)); 65 | } 66 | 67 | template auto Add(ShardId sid, F&& f) { 68 | assert(sid < shard_queue_.size()); 69 | return shard_queue_[sid]->Add(std::forward(f)); 70 | } 71 | 72 | template void RunBriefInParallel(U&& func); 73 | 74 | private: 75 | util::ProactorPool* pp_; 76 | std::vector shard_queue_; 77 | }; 78 | 79 | /** 80 | * @brief 81 | * 82 | * @tparam U - a function that receives EngineShard* argument and returns void. 
83 | * @param func 84 | */ 85 | template void EngineShardSet::RunBriefInParallel(U&& func) { 86 | util::fibers_ext::BlockingCounter bc{size()}; 87 | 88 | for (uint32_t i = 0; i < size(); ++i) { 89 | util::ProactorBase* dest = pp_->at(i); 90 | dest->DispatchBrief([f = std::forward(func), bc]() mutable { 91 | f(EngineShard::tlocal()); 92 | bc.Dec(); 93 | }); 94 | } 95 | bc.Wait(); 96 | } 97 | 98 | } // namespace dfly 99 | -------------------------------------------------------------------------------- /server/main_service.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | 5 | #include "server/main_service.h" 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | #include "base/logging.h" 14 | #include "base/flags.h" 15 | #include "server/conn_context.h" 16 | #include "server/debugcmd.h" 17 | #include "util/metrics/metrics.h" 18 | #include "util/uring/uring_fiber_algo.h" 19 | #include "util/varz.h" 20 | 21 | ABSL_FLAG(uint32_t, port, 6380, "Redis port"); 22 | ABSL_FLAG(uint32_t, memcache_port, 0, "Memcached port"); 23 | 24 | namespace dfly { 25 | 26 | using namespace std; 27 | using namespace util; 28 | using base::VarzValue; 29 | 30 | namespace fibers = ::boost::fibers; 31 | namespace this_fiber = ::boost::this_fiber; 32 | 33 | namespace { 34 | 35 | DEFINE_VARZ(VarzMapAverage, request_latency_usec); 36 | DEFINE_VARZ(VarzQps, ping_qps); 37 | DEFINE_VARZ(VarzQps, set_qps); 38 | 39 | optional engine_varz; 40 | metrics::CounterFamily cmd_req("requests_total", "Number of served redis requests"); 41 | 42 | } // namespace 43 | 44 | Service::Service(ProactorPool* pp) : shard_set_(pp), pp_(*pp) { 45 | CHECK(pp); 46 | RegisterCommands(); 47 | engine_varz.emplace("engine", [this] { return GetVarzStats(); }); 48 | } 49 | 50 | Service::~Service() { 51 | } 52 | 53 | void Service::Init(util::AcceptServer* acceptor) { 54 | 
uint32_t shard_num = pp_.size() > 1 ? pp_.size() - 1 : pp_.size(); 55 | shard_set_.Init(shard_num); 56 | 57 | pp_.Await([&](uint32_t index, ProactorBase* pb) { 58 | if (index < shard_count()) { 59 | shard_set_.InitThreadLocal(index); 60 | } 61 | }); 62 | 63 | request_latency_usec.Init(&pp_); 64 | ping_qps.Init(&pp_); 65 | set_qps.Init(&pp_); 66 | cmd_req.Init(&pp_, {"type"}); 67 | } 68 | 69 | void Service::Shutdown() { 70 | VLOG(1) << "Service::Shutdown"; 71 | 72 | engine_varz.reset(); 73 | request_latency_usec.Shutdown(); 74 | ping_qps.Shutdown(); 75 | set_qps.Shutdown(); 76 | 77 | shard_set_.RunBriefInParallel([&](EngineShard*) { EngineShard::DestroyThreadLocal(); }); 78 | } 79 | 80 | void Service::DispatchCommand(CmdArgList args, ConnectionContext* cntx) { 81 | CHECK(!args.empty()); 82 | DCHECK_NE(0u, shard_set_.size()) << "Init was not called"; 83 | 84 | ToUpper(&args[0]); 85 | 86 | VLOG(2) << "Got: " << args; 87 | 88 | string_view cmd_str = ArgS(args, 0); 89 | const CommandId* cid = registry_.Find(cmd_str); 90 | 91 | if (cid == nullptr) { 92 | return cntx->SendError(absl::StrCat("unknown command `", cmd_str, "`")); 93 | } 94 | 95 | if ((cid->arity() > 0 && args.size() != size_t(cid->arity())) || 96 | (cid->arity() < 0 && args.size() < size_t(-cid->arity()))) { 97 | return cntx->SendError(WrongNumArgsError(cmd_str)); 98 | } 99 | uint64_t start_usec = ProactorBase::GetMonotonicTimeNs(), end_usec; 100 | cntx->cid = cid; 101 | cmd_req.Inc({cid->name()}); 102 | cid->Invoke(args, cntx); 103 | end_usec = ProactorBase::GetMonotonicTimeNs(); 104 | 105 | request_latency_usec.IncBy(cmd_str, (end_usec - start_usec) / 1000); 106 | } 107 | 108 | void Service::DispatchMC(const MemcacheParser::Command& cmd, std::string_view value, 109 | ConnectionContext* cntx) { 110 | absl::InlinedVector args; 111 | char cmd_name[16]; 112 | char set_opt[4] = {0}; 113 | 114 | switch (cmd.type) { 115 | case MemcacheParser::REPLACE: 116 | strcpy(cmd_name, "SET"); 117 | strcpy(set_opt, "XX"); 
118 | break; 119 | case MemcacheParser::SET: 120 | strcpy(cmd_name, "SET"); 121 | break; 122 | case MemcacheParser::ADD: 123 | strcpy(cmd_name, "SET"); 124 | strcpy(set_opt, "NX"); 125 | break; 126 | case MemcacheParser::GET: 127 | strcpy(cmd_name, "GET"); 128 | break; 129 | default: 130 | cntx->SendMCClientError("bad command line format"); 131 | return; 132 | } 133 | 134 | args.emplace_back(cmd_name, strlen(cmd_name)); 135 | char* key = const_cast(cmd.key.data()); 136 | args.emplace_back(key, cmd.key.size()); 137 | 138 | if (MemcacheParser::IsStoreCmd(cmd.type)) { 139 | char* v = const_cast(value.data()); 140 | args.emplace_back(v, value.size()); 141 | 142 | if (set_opt[0]) { 143 | args.emplace_back(set_opt, strlen(set_opt)); 144 | } 145 | } 146 | 147 | CmdArgList arg_list{args.data(), args.size()}; 148 | DispatchCommand(arg_list, cntx); 149 | } 150 | 151 | void Service::RegisterHttp(HttpListenerBase* listener) { 152 | CHECK_NOTNULL(listener); 153 | } 154 | 155 | void Service::Ping(CmdArgList args, ConnectionContext* cntx) { 156 | if (args.size() > 2) { 157 | return cntx->SendError("wrong number of arguments for 'ping' command"); 158 | } 159 | ping_qps.Inc(); 160 | 161 | if (args.size() == 1) { 162 | return cntx->SendSimpleRespString("PONG"); 163 | } 164 | std::string_view arg = ArgS(args, 1); 165 | DVLOG(2) << "Ping " << arg; 166 | 167 | return cntx->SendSimpleRespString(arg); 168 | } 169 | 170 | void Service::Set(CmdArgList args, ConnectionContext* cntx) { 171 | set_qps.Inc(); 172 | 173 | string_view key = ArgS(args, 1); 174 | string_view val = ArgS(args, 2); 175 | VLOG(2) << "Set " << key << " " << val; 176 | 177 | ShardId sid = Shard(key, shard_count()); 178 | shard_set_.Await(sid, [&] { 179 | EngineShard* es = EngineShard::tlocal(); 180 | auto [it, res] = es->db_slice.AddOrFind(0, key); 181 | it->second = val; 182 | }); 183 | 184 | cntx->SendStored(); 185 | } 186 | 187 | void Service::Get(CmdArgList args, ConnectionContext* cntx) { 188 | string_view key = 
ArgS(args, 1); 189 | ShardId sid = Shard(key, shard_count()); 190 | 191 | OpResult opres; 192 | 193 | shard_set_.Await(sid, [&] { 194 | EngineShard* es = EngineShard::tlocal(); 195 | OpResult res = es->db_slice.Find(0, key); 196 | if (res) { 197 | opres.value() = res.value()->second; 198 | } else { 199 | opres = res.status(); 200 | } 201 | }); 202 | 203 | if (opres) { 204 | cntx->SendGetReply(key, 0, opres.value()); 205 | } else if (opres.status() == OpStatus::KEY_NOTFOUND) { 206 | cntx->SendGetNotFound(); 207 | } 208 | cntx->EndMultilineReply(); 209 | } 210 | 211 | void Service::Debug(CmdArgList args, ConnectionContext* cntx) { 212 | ToUpper(&args[1]); 213 | 214 | DebugCmd dbg_cmd{&shard_set_, cntx}; 215 | 216 | return dbg_cmd.Run(args); 217 | } 218 | 219 | VarzValue::Map Service::GetVarzStats() { 220 | VarzValue::Map res; 221 | 222 | atomic_ulong num_keys{0}; 223 | shard_set_.RunBriefInParallel([&](EngineShard* es) { num_keys += es->db_slice.DbSize(0); }); 224 | res.emplace_back("keys", VarzValue::FromInt(num_keys.load())); 225 | 226 | return res; 227 | } 228 | 229 | using ServiceFunc = void (Service::*)(CmdArgList args, ConnectionContext* cntx); 230 | inline CommandId::Handler HandlerFunc(Service* se, ServiceFunc f) { 231 | return [=](CmdArgList args, ConnectionContext* cntx) { return (se->*f)(args, cntx); }; 232 | } 233 | 234 | #define HFUNC(x) SetHandler(HandlerFunc(this, &Service::x)) 235 | 236 | void Service::RegisterCommands() { 237 | using CI = CommandId; 238 | 239 | registry_ << CI{"PING", CO::STALE | CO::FAST, -1, 0, 0, 0}.HFUNC(Ping) 240 | << CI{"SET", CO::WRITE | CO::DENYOOM, -3, 1, 1, 1}.HFUNC(Set) 241 | << CI{"GET", CO::READONLY | CO::FAST, 2, 1, 1, 1}.HFUNC(Get) 242 | << CI{"DEBUG", CO::RANDOM | CO::READONLY, -2, 0, 0, 0}.HFUNC(Debug); 243 | } 244 | 245 | } // namespace dfly 246 | -------------------------------------------------------------------------------- /server/main_service.h: 
-------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | 5 | #pragma once 6 | 7 | #include "base/varz_value.h" 8 | #include "server/command_registry.h" 9 | #include "server/engine_shard_set.h" 10 | #include "util/http/http_handler.h" 11 | #include "server/memcache_parser.h" 12 | 13 | namespace util { 14 | class AcceptServer; 15 | } // namespace util 16 | 17 | namespace dfly { 18 | 19 | class Service { 20 | public: 21 | using error_code = std::error_code; 22 | 23 | explicit Service(util::ProactorPool* pp); 24 | ~Service(); 25 | 26 | void RegisterHttp(util::HttpListenerBase* listener); 27 | 28 | void Init(util::AcceptServer* acceptor); 29 | 30 | void Shutdown(); 31 | 32 | void DispatchCommand(CmdArgList args, ConnectionContext* cntx); 33 | void DispatchMC(const MemcacheParser::Command& cmd, std::string_view value, 34 | ConnectionContext* cntx); 35 | 36 | uint32_t shard_count() const { 37 | return shard_set_.size(); 38 | } 39 | 40 | EngineShardSet& shard_set() { 41 | return shard_set_; 42 | } 43 | 44 | util::ProactorPool& proactor_pool() { 45 | return pp_; 46 | } 47 | 48 | private: 49 | void Ping(CmdArgList args, ConnectionContext* cntx); 50 | void Set(CmdArgList args, ConnectionContext* cntx); 51 | void Get(CmdArgList args, ConnectionContext* cntx); 52 | void Debug(CmdArgList args, ConnectionContext* cntx); 53 | 54 | void RegisterCommands(); 55 | 56 | base::VarzValue::Map GetVarzStats(); 57 | 58 | CommandRegistry registry_; 59 | EngineShardSet shard_set_; 60 | util::ProactorPool& pp_; 61 | }; 62 | 63 | } // namespace dfly 64 | -------------------------------------------------------------------------------- /server/memcache_parser.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | #include "server/memcache_parser.h" 5 | 6 | #include 7 | #include 8 | 9 | namespace dfly { 10 | using namespace std; 11 | 12 | namespace { 13 | 14 | pair cmd_map[] = { 15 | {"set", MemcacheParser::SET}, {"add", MemcacheParser::ADD}, 16 | {"replace", MemcacheParser::REPLACE}, {"append", MemcacheParser::APPEND}, 17 | {"prepend", MemcacheParser::PREPEND}, {"cas", MemcacheParser::CAS}, 18 | {"get", MemcacheParser::GET}, {"gets", MemcacheParser::GETS}, 19 | {"gat", MemcacheParser::GAT}, {"gats", MemcacheParser::GATS}, 20 | }; 21 | 22 | MemcacheParser::CmdType From(string_view token) { 23 | for (const auto& k_v : cmd_map) { 24 | if (token == k_v.first) 25 | return k_v.second; 26 | } 27 | return MemcacheParser::INVALID; 28 | } 29 | 30 | MemcacheParser::Result ParseStore(const std::string_view* tokens, unsigned num_tokens, 31 | MemcacheParser::Command* res) { 32 | 33 | unsigned opt_pos = 3; 34 | if (res->type == MemcacheParser::CAS) { 35 | if (num_tokens <= opt_pos) 36 | return MemcacheParser::PARSE_ERROR; 37 | ++opt_pos; 38 | } 39 | 40 | uint32_t flags; 41 | if (!absl::SimpleAtoi(tokens[0], &flags) || !absl::SimpleAtoi(tokens[1], &res->expire_ts) || 42 | !absl::SimpleAtoi(tokens[2], &res->bytes_len)) 43 | return MemcacheParser::BAD_INT; 44 | 45 | if (flags > 0xFFFF) 46 | return MemcacheParser::BAD_INT; 47 | 48 | if (res->type == MemcacheParser::CAS && !absl::SimpleAtoi(tokens[3], &res->cas_unique)) { 49 | return MemcacheParser::BAD_INT; 50 | } 51 | 52 | res->flags = flags; 53 | if (num_tokens == opt_pos + 1) { 54 | if (tokens[opt_pos] == "noreply") { 55 | res->no_reply = true; 56 | } else { 57 | return MemcacheParser::PARSE_ERROR; 58 | } 59 | } else if (num_tokens > opt_pos + 1) { 60 | return MemcacheParser::PARSE_ERROR; 61 | } 62 | 63 | return MemcacheParser::OK; 64 | } 65 | 66 | MemcacheParser::Result ParseRetrieve(const std::string_view* tokens, unsigned num_tokens, 67 | MemcacheParser::Command* res) { 68 | unsigned key_pos = 0; 69 | if (res->type == 
MemcacheParser::GAT || res->type == MemcacheParser::GATS) { 70 | if (!absl::SimpleAtoi(tokens[0], &res->expire_ts)) { 71 | return MemcacheParser::BAD_INT; 72 | } 73 | ++key_pos; 74 | } 75 | res->key = tokens[key_pos++]; 76 | while (key_pos < num_tokens) { 77 | res->keys_ext.push_back(tokens[key_pos++]); 78 | } 79 | 80 | return MemcacheParser::OK; 81 | } 82 | 83 | } // namespace 84 | 85 | auto MemcacheParser::Parse(string_view str, uint32_t* consumed, Command* res) -> Result { 86 | auto pos = str.find('\n'); 87 | *consumed = 0; 88 | if (pos == string_view::npos) { 89 | // TODO: it's over simplified since we may process gets command that is not limited to 90 | // 300 characters. 91 | return str.size() > 300 ? PARSE_ERROR : INPUT_PENDING; 92 | } 93 | if (pos == 0 || str[pos - 1] != '\r') { 94 | return PARSE_ERROR; 95 | } 96 | *consumed = pos + 1; 97 | 98 | // cas [noreply]\r\n 99 | // get *\r\n 100 | string_view tokens[8]; 101 | unsigned num_tokens = 0; 102 | uint32_t cur = 0; 103 | 104 | while (cur < pos && str[cur] == ' ') 105 | ++cur; 106 | uint32_t s = cur; 107 | for (; cur < pos; ++cur) { 108 | if (str[cur] == ' ' || str[cur] == '\r') { 109 | if (cur != s) { 110 | tokens[num_tokens++] = str.substr(s, cur - s); 111 | if (num_tokens == ABSL_ARRAYSIZE(tokens)) { 112 | ++cur; 113 | s = cur; 114 | break; 115 | } 116 | } 117 | s = cur + 1; 118 | } 119 | } 120 | if (num_tokens == 0) 121 | return PARSE_ERROR; 122 | 123 | while (cur < pos - 1) { 124 | if (str[cur] != ' ') 125 | return PARSE_ERROR; 126 | ++cur; 127 | } 128 | 129 | res->type = From(tokens[0]); 130 | if (res->type == INVALID) { 131 | return UNKNOWN_CMD; 132 | } 133 | 134 | if (res->type <= CAS) { // Store command 135 | if (num_tokens < 5 || tokens[1].size() > 250) { 136 | return MemcacheParser::PARSE_ERROR; 137 | } 138 | 139 | // memcpy(single_key_, tokens[0].data(), tokens[0].size()); // we copy the key 140 | res->key = string_view{tokens[1].data(), tokens[1].size()}; 141 | 142 | return ParseStore(tokens + 
2, num_tokens - 2, res); 143 | } 144 | 145 | return ParseRetrieve(tokens + 1, num_tokens - 1, res); 146 | }; 147 | 148 | } // namespace dfly -------------------------------------------------------------------------------- /server/memcache_parser.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | 10 | namespace dfly { 11 | 12 | // Memcache parser does not parse value blobs, only the commands. 13 | // The expectation is that the caller will parse the command and 14 | // then will follow up with reading the blob data directly from source. 15 | class MemcacheParser { 16 | public: 17 | enum CmdType { 18 | INVALID = 0, 19 | SET = 1, 20 | ADD = 2, 21 | REPLACE = 3, 22 | APPEND = 4, 23 | PREPEND = 5, 24 | CAS = 6, 25 | 26 | // Retrieval 27 | GET = 10, 28 | GETS = 11, 29 | GAT = 12, 30 | GATS = 13, 31 | 32 | // Delete and INCR 33 | DELETE = 21, 34 | INCR = 22, 35 | DECR = 23, 36 | }; 37 | 38 | struct Command { 39 | CmdType type = INVALID; 40 | std::string_view key; 41 | std::vector keys_ext; 42 | 43 | uint64_t cas_unique = 0; 44 | uint32_t expire_ts = 0; 45 | uint32_t bytes_len = 0; 46 | uint16_t flags = 0; 47 | bool no_reply = false; 48 | }; 49 | 50 | enum Result { 51 | OK, 52 | INPUT_PENDING, 53 | UNKNOWN_CMD, 54 | BAD_INT, 55 | PARSE_ERROR, 56 | }; 57 | 58 | static bool IsStoreCmd(CmdType type) { 59 | return type >= SET && type <= CAS; 60 | } 61 | 62 | Result Parse(std::string_view str, uint32_t* consumed, Command* res); 63 | 64 | private: 65 | }; 66 | 67 | } // namespace dfly -------------------------------------------------------------------------------- /server/memcache_parser_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | 5 | #include "server/memcache_parser.h" 6 | 7 | #include 8 | 9 | #include "absl/strings/str_cat.h" 10 | #include "base/gtest.h" 11 | #include "base/logging.h" 12 | #include "server/test_utils.h" 13 | 14 | using namespace testing; 15 | using namespace std; 16 | namespace dfly { 17 | 18 | class MCParserTest : public testing::Test { 19 | protected: 20 | RedisParser::Result Parse(std::string_view str); 21 | 22 | MemcacheParser parser_; 23 | MemcacheParser::Command cmd_; 24 | uint32_t consumed_; 25 | 26 | unique_ptr stash_; 27 | }; 28 | 29 | 30 | TEST_F(MCParserTest, Basic) { 31 | MemcacheParser::Result st = parser_.Parse("set a 1 20 3\r\n", &consumed_, &cmd_); 32 | EXPECT_EQ(MemcacheParser::OK, st); 33 | EXPECT_EQ("a", cmd_.key); 34 | EXPECT_EQ(1, cmd_.flags); 35 | EXPECT_EQ(20, cmd_.expire_ts); 36 | EXPECT_EQ(3, cmd_.bytes_len); 37 | } 38 | 39 | } // namespace dfly -------------------------------------------------------------------------------- /server/op_status.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | namespace dfly { 10 | 11 | enum class OpStatus : uint16_t { 12 | OK, 13 | KEY_NOTFOUND, 14 | }; 15 | 16 | class OpResultBase { 17 | public: 18 | OpResultBase(OpStatus st = OpStatus::OK) : st_(st) { 19 | } 20 | 21 | constexpr explicit operator bool() const { 22 | return st_ == OpStatus::OK; 23 | } 24 | 25 | OpStatus status() const { 26 | return st_; 27 | } 28 | 29 | bool operator==(OpStatus st) const { 30 | return st_ == st; 31 | } 32 | 33 | bool ok() const { 34 | return st_ == OpStatus::OK; 35 | } 36 | 37 | private: 38 | OpStatus st_; 39 | }; 40 | 41 | template class OpResult : public OpResultBase { 42 | public: 43 | OpResult(V v) : v_(std::move(v)) { 44 | } 45 | 46 | using OpResultBase::OpResultBase; 47 | 48 | const V& value() const { 49 | return v_; 50 | } 51 | 52 | V& value() { 53 | return v_; 54 | } 55 | 56 | V value_or(V v) const { 57 | return status() == OpStatus::OK ? v_ : v; 58 | } 59 | 60 | const V* operator->() const { 61 | return &v_; 62 | } 63 | 64 | private: 65 | V v_; 66 | }; 67 | 68 | template <> class OpResult : public OpResultBase { 69 | public: 70 | using OpResultBase::OpResultBase; 71 | }; 72 | 73 | inline bool operator==(OpStatus st, const OpResultBase& ob) { 74 | return ob.operator==(st); 75 | } 76 | 77 | } // namespace dfly 78 | 79 | namespace std { 80 | 81 | template std::ostream& operator<<(std::ostream& os, const dfly::OpResult& res) { 82 | os << int(res.status()); 83 | return os; 84 | } 85 | 86 | inline std::ostream& operator<<(std::ostream& os, const dfly::OpStatus op) { 87 | os << int(op); 88 | return os; 89 | } 90 | 91 | } // namespace std -------------------------------------------------------------------------------- /server/redis_parser.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | #include "server/redis_parser.h" 5 | 6 | #include 7 | 8 | #include "base/logging.h" 9 | 10 | namespace dfly { 11 | 12 | using namespace std; 13 | 14 | namespace { 15 | 16 | constexpr int kMaxArrayLen = 1024; 17 | constexpr int64_t kMaxBulkLen = 64 * (1ul << 20); // 64MB. 18 | 19 | } // namespace 20 | 21 | auto RedisParser::Parse(Buffer str, uint32_t* consumed, RespExpr::Vec* res) -> Result { 22 | *consumed = 0; 23 | res->clear(); 24 | 25 | if (str.size() < 2) { 26 | return INPUT_PENDING; 27 | } 28 | 29 | if (state_ == CMD_COMPLETE_S) { 30 | state_ = INIT_S; 31 | } 32 | 33 | if (state_ == INIT_S) { 34 | InitStart(str[0], res); 35 | } 36 | 37 | if (!cached_expr_) 38 | cached_expr_ = res; 39 | 40 | while (state_ != CMD_COMPLETE_S) { 41 | last_consumed_ = 0; 42 | switch (state_) { 43 | case ARRAY_LEN_S: 44 | last_result_ = ConsumeArrayLen(str); 45 | break; 46 | case PARSE_ARG_S: 47 | if (str.size() < 4) { 48 | last_result_ = INPUT_PENDING; 49 | } else { 50 | last_result_ = ParseArg(str); 51 | } 52 | break; 53 | case INLINE_S: 54 | DCHECK(parse_stack_.empty()); 55 | last_result_ = ParseInline(str); 56 | break; 57 | case BULK_STR_S: 58 | last_result_ = ConsumeBulk(str); 59 | break; 60 | case FINISH_ARG_S: 61 | HandleFinishArg(); 62 | break; 63 | default: 64 | LOG(FATAL) << "Unexpected state " << int(state_); 65 | } 66 | 67 | *consumed += last_consumed_; 68 | 69 | if (last_result_ != OK) { 70 | break; 71 | } 72 | str.remove_prefix(last_consumed_); 73 | } 74 | 75 | if (last_result_ == INPUT_PENDING) { 76 | StashState(res); 77 | } else if (last_result_ == OK) { 78 | DCHECK(cached_expr_); 79 | if (res != cached_expr_) { 80 | DCHECK(!stash_.empty()); 81 | 82 | *res = *cached_expr_; 83 | } 84 | } 85 | 86 | return last_result_; 87 | } 88 | 89 | void RedisParser::InitStart(uint8_t prefix_b, RespExpr::Vec* res) { 90 | buf_stash_.clear(); 91 | stash_.clear(); 92 | cached_expr_ = res; 93 | parse_stack_.clear(); 94 | last_stashed_level_ = 0; 95 | last_stashed_index_ = 0; 
96 | 97 | switch (prefix_b) { 98 | case '$': 99 | state_ = PARSE_ARG_S; 100 | parse_stack_.emplace_back(1, cached_expr_); // expression of length 1. 101 | break; 102 | case '*': 103 | state_ = ARRAY_LEN_S; 104 | break; 105 | default: 106 | state_ = INLINE_S; 107 | break; 108 | } 109 | } 110 | 111 | void RedisParser::StashState(RespExpr::Vec* res) { 112 | if (cached_expr_->empty() && stash_.empty()) { 113 | cached_expr_ = nullptr; 114 | return; 115 | } 116 | 117 | if (cached_expr_ == res) { 118 | stash_.emplace_back(new RespExpr::Vec(*res)); 119 | cached_expr_ = stash_.back().get(); 120 | } 121 | 122 | DCHECK_LT(last_stashed_level_, stash_.size()); 123 | while (true) { 124 | auto& cur = *stash_[last_stashed_level_]; 125 | 126 | for (; last_stashed_index_ < cur.size(); ++last_stashed_index_) { 127 | auto& e = cur[last_stashed_index_]; 128 | if (RespExpr::STRING == e.type) { 129 | Buffer& ebuf = get(e.u); 130 | if (ebuf.empty() && last_stashed_index_ + 1 == cur.size()) 131 | break; 132 | if (!ebuf.empty() && !e.has_support) { 133 | BlobPtr ptr(new uint8_t[ebuf.size()]); 134 | memcpy(ptr.get(), ebuf.data(), ebuf.size()); 135 | ebuf = Buffer{ptr.get(), ebuf.size()}; 136 | buf_stash_.push_back(std::move(ptr)); 137 | e.has_support = true; 138 | } 139 | } 140 | } 141 | 142 | if (last_stashed_level_ + 1 == stash_.size()) 143 | break; 144 | ++last_stashed_level_; 145 | last_stashed_index_ = 0; 146 | } 147 | } 148 | 149 | auto RedisParser::ParseInline(Buffer str) -> Result { 150 | DCHECK(!str.empty()); 151 | 152 | uint8_t* ptr = str.begin(); 153 | uint8_t* end = str.end(); 154 | uint8_t* token_start = ptr; 155 | 156 | if (is_broken_token_) { 157 | while (ptr != end && *ptr > 32) 158 | ++ptr; 159 | 160 | size_t len = ptr - token_start; 161 | 162 | ExtendLastString(Buffer(token_start, len)); 163 | if (ptr != end) { 164 | is_broken_token_ = false; 165 | } 166 | } 167 | 168 | auto is_finish = [&] { return ptr == end || *ptr == '\n'; }; 169 | 170 | while (true) { 171 | while 
(!is_finish() && *ptr <= 32) { 172 | ++ptr; 173 | } 174 | // We do not test for \r in order to accept 'nc' input. 175 | if (is_finish()) 176 | break; 177 | 178 | DCHECK(!is_broken_token_); 179 | 180 | token_start = ptr; 181 | while (ptr != end && *ptr > 32) 182 | ++ptr; 183 | 184 | cached_expr_->emplace_back(RespExpr::STRING); 185 | cached_expr_->back().u = Buffer{token_start, size_t(ptr - token_start)}; 186 | } 187 | 188 | last_consumed_ = ptr - str.data(); 189 | if (ptr == end) { // we have not finished parsing. 190 | if (ptr[-1] > 32) { 191 | // we stopped in the middle of the token. 192 | is_broken_token_ = true; 193 | } 194 | 195 | return INPUT_PENDING; 196 | } else { 197 | ++last_consumed_; // consume the delimiter as well. 198 | } 199 | state_ = CMD_COMPLETE_S; 200 | 201 | return OK; 202 | } 203 | 204 | auto RedisParser::ParseNum(Buffer str, int64_t* res) -> Result { 205 | if (str.size() < 4) { 206 | return INPUT_PENDING; 207 | } 208 | 209 | char* s = reinterpret_cast(str.data() + 1); 210 | char* pos = reinterpret_cast(memchr(s, '\n', str.size() - 1)); 211 | if (!pos) { 212 | return str.size() < 32 ? INPUT_PENDING : BAD_INT; 213 | } 214 | if (pos[-1] != '\r') { 215 | return BAD_INT; 216 | } 217 | 218 | bool success = absl::SimpleAtoi(std::string_view{s, size_t(pos - s - 1)}, res); 219 | if (!success) { 220 | return BAD_INT; 221 | } 222 | last_consumed_ = (pos - s) + 2; 223 | 224 | return OK; 225 | } 226 | 227 | auto RedisParser::ConsumeArrayLen(Buffer str) -> Result { 228 | int64_t len; 229 | 230 | Result res = ParseNum(str, &len); 231 | switch (res) { 232 | case INPUT_PENDING: 233 | return INPUT_PENDING; 234 | case BAD_INT: 235 | return BAD_ARRAYLEN; 236 | case OK: 237 | if (len < -1 || len > kMaxArrayLen) 238 | return BAD_ARRAYLEN; 239 | break; 240 | default: 241 | LOG(ERROR) << "Unexpected result " << res; 242 | } 243 | 244 | // Already parsed array expression somewhere. Server should accept only single-level expressions. 
245 | if (!parse_stack_.empty()) 246 | return BAD_STRING; 247 | 248 | // Similarly if our cached expr is not empty. 249 | if (!cached_expr_->empty()) 250 | return BAD_STRING; 251 | 252 | if (len <= 0) { 253 | cached_expr_->emplace_back(len == -1 ? RespExpr::NIL_ARRAY : RespExpr::ARRAY); 254 | if (len < 0) 255 | cached_expr_->back().u.emplace(nullptr); // nil 256 | else { 257 | static RespVec empty_vec; 258 | cached_expr_->back().u = &empty_vec; 259 | } 260 | state_ = (parse_stack_.empty()) ? CMD_COMPLETE_S : FINISH_ARG_S; 261 | 262 | return OK; 263 | } 264 | 265 | parse_stack_.emplace_back(len, cached_expr_); 266 | DCHECK(cached_expr_->empty()); 267 | state_ = PARSE_ARG_S; 268 | 269 | return OK; 270 | } 271 | 272 | auto RedisParser::ParseArg(Buffer str) -> Result { 273 | char c = str[0]; 274 | if (c == '$') { 275 | int64_t len; 276 | 277 | Result res = ParseNum(str, &len); 278 | switch (res) { 279 | case INPUT_PENDING: 280 | return INPUT_PENDING; 281 | case BAD_INT: 282 | return BAD_ARRAYLEN; 283 | case OK: 284 | if (len < -1 || len > kMaxBulkLen) 285 | return BAD_ARRAYLEN; 286 | break; 287 | default: 288 | LOG(ERROR) << "Unexpected result " << res; 289 | } 290 | 291 | if (len < 0) { 292 | state_ = FINISH_ARG_S; 293 | cached_expr_->emplace_back(RespExpr::NIL); 294 | } else { 295 | cached_expr_->emplace_back(RespExpr::STRING); 296 | bulk_len_ = len; 297 | state_ = BULK_STR_S; 298 | } 299 | cached_expr_->back().u = Buffer{}; 300 | 301 | return OK; 302 | } 303 | 304 | return BAD_BULKLEN; 305 | } 306 | 307 | auto RedisParser::ConsumeBulk(Buffer str) -> Result { 308 | auto& bulk_str = get(cached_expr_->back().u); 309 | 310 | if (str.size() >= bulk_len_ + 2) { 311 | if (str[bulk_len_] != '\r' || str[bulk_len_ + 1] != '\n') { 312 | return BAD_STRING; 313 | } 314 | 315 | if (bulk_len_) { 316 | if (is_broken_token_) { 317 | memcpy(bulk_str.end(), str.data(), bulk_len_); 318 | bulk_str = Buffer{bulk_str.data(), bulk_str.size() + bulk_len_}; 319 | } else { 320 | bulk_str = 
str.subspan(0, bulk_len_); 321 | } 322 | } 323 | is_broken_token_ = false; 324 | state_ = FINISH_ARG_S; 325 | last_consumed_ = bulk_len_ + 2; 326 | bulk_len_ = 0; 327 | 328 | return OK; 329 | } 330 | 331 | if (str.size() >= 32) { 332 | DCHECK(bulk_len_); 333 | size_t len = std::min(str.size(), bulk_len_); 334 | 335 | if (is_broken_token_) { 336 | memcpy(bulk_str.end(), str.data(), len); 337 | bulk_str = Buffer{bulk_str.data(), bulk_str.size() + len}; 338 | DVLOG(1) << "Extending bulk stash to size " << bulk_str.size(); 339 | } else { 340 | DVLOG(1) << "New bulk stash size " << bulk_len_; 341 | std::unique_ptr nb(new uint8_t[bulk_len_]); 342 | memcpy(nb.get(), str.data(), len); 343 | bulk_str = Buffer{nb.get(), len}; 344 | buf_stash_.emplace_back(move(nb)); 345 | is_broken_token_ = true; 346 | cached_expr_->back().has_support = true; 347 | } 348 | last_consumed_ = len; 349 | bulk_len_ -= len; 350 | } 351 | 352 | return INPUT_PENDING; 353 | } 354 | 355 | void RedisParser::HandleFinishArg() { 356 | DCHECK(!parse_stack_.empty()); 357 | DCHECK_GT(parse_stack_.back().first, 0u); 358 | 359 | while (true) { 360 | --parse_stack_.back().first; 361 | state_ = PARSE_ARG_S; 362 | if (parse_stack_.back().first != 0) 363 | break; 364 | 365 | parse_stack_.pop_back(); // pop 0. 
366 | if (parse_stack_.empty()) { 367 | state_ = CMD_COMPLETE_S; 368 | break; 369 | } 370 | cached_expr_ = parse_stack_.back().second; 371 | } 372 | } 373 | 374 | void RedisParser::ExtendLastString(Buffer str) { 375 | DCHECK(!cached_expr_->empty() && cached_expr_->back().type == RespExpr::STRING); 376 | DCHECK(!buf_stash_.empty()); 377 | 378 | Buffer& last_str = get(cached_expr_->back().u); 379 | 380 | DCHECK(last_str.data() == buf_stash_.back().get()); 381 | 382 | std::unique_ptr nb(new uint8_t[last_str.size() + str.size()]); 383 | memcpy(nb.get(), last_str.data(), last_str.size()); 384 | memcpy(nb.get() + last_str.size(), str.data(), str.size()); 385 | last_str = RespExpr::Buffer{nb.get(), last_str.size() + str.size()}; 386 | buf_stash_.back() = std::move(nb); 387 | } 388 | 389 | } // namespace dfly 390 | -------------------------------------------------------------------------------- /server/redis_parser.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | #pragma once 5 | 6 | #include 7 | 8 | #include "resp_expr.h" 9 | 10 | namespace dfly { 11 | 12 | /** 13 | * @brief Zero-copy (best-effort) parser. 14 | * 15 | */ 16 | class RedisParser { 17 | public: 18 | enum Result { 19 | OK, 20 | INPUT_PENDING, 21 | BAD_ARRAYLEN, 22 | BAD_BULKLEN, 23 | BAD_STRING, 24 | BAD_INT 25 | }; 26 | using Buffer = RespExpr::Buffer; 27 | 28 | explicit RedisParser() { 29 | } 30 | 31 | /** 32 | * @brief Parses str into res. "consumed" stores number of bytes consumed from str. 33 | * 34 | * A caller should not invalidate str if the parser returns RESP_OK as long as he continues 35 | * accessing res. However, if parser returns MORE_INPUT a caller may discard consumed 36 | * part of str because parser caches the intermediate state internally according to 'consumed' 37 | * result. 38 | * 39 | * Note: A parser does not always guarantee progress, i.e. 
if a small buffer was passed it may 40 | * returns MORE_INPUT with consumed == 0. 41 | * 42 | */ 43 | 44 | Result Parse(Buffer str, uint32_t* consumed, RespVec* res); 45 | 46 | size_t stash_size() const { return stash_.size(); } 47 | const std::vector>& stash() const { return stash_;} 48 | 49 | private: 50 | void InitStart(uint8_t prefix_b, RespVec* res); 51 | void StashState(RespVec* res); 52 | 53 | // Skips the first character (*). 54 | Result ConsumeArrayLen(Buffer str); 55 | Result ParseArg(Buffer str); 56 | Result ConsumeBulk(Buffer str); 57 | Result ParseInline(Buffer str); 58 | 59 | // Updates last_consumed_ 60 | Result ParseNum(Buffer str, int64_t* res); 61 | void HandleFinishArg(); 62 | void ExtendLastString(Buffer str); 63 | 64 | enum State : uint8_t { 65 | INIT_S = 0, 66 | INLINE_S, 67 | ARRAY_LEN_S, 68 | PARSE_ARG_S, // Parse [$:+-]string\r\n 69 | BULK_STR_S, 70 | FINISH_ARG_S, 71 | CMD_COMPLETE_S, 72 | }; 73 | 74 | State state_ = INIT_S; 75 | Result last_result_ = OK; 76 | 77 | uint32_t last_consumed_ = 0; 78 | uint32_t bulk_len_ = 0; 79 | uint32_t last_stashed_level_ = 0, last_stashed_index_ = 0; 80 | 81 | // expected expression length, pointer to expression vector. 82 | absl::InlinedVector, 4> parse_stack_; 83 | std::vector> stash_; 84 | 85 | using BlobPtr = std::unique_ptr; 86 | std::vector buf_stash_; 87 | RespVec* cached_expr_ = nullptr; 88 | bool is_broken_token_ = false; 89 | }; 90 | 91 | } // namespace dfly 92 | -------------------------------------------------------------------------------- /server/redis_parser_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | 5 | #include "server/redis_parser.h" 6 | 7 | #include 8 | #include 9 | 10 | #include "absl/strings/str_cat.h" 11 | #include "base/gtest.h" 12 | #include "base/logging.h" 13 | #include "server/test_utils.h" 14 | 15 | using namespace testing; 16 | using namespace std; 17 | namespace dfly { 18 | 19 | MATCHER_P(ArrArg, expected, absl::StrCat(negation ? "is not" : "is", " equal to:\n", expected)) { 20 | if (arg.type != RespExpr::ARRAY) { 21 | *result_listener << "\nWrong type: " << arg.type; 22 | return false; 23 | } 24 | size_t exp_sz = expected; 25 | size_t actual = get(arg.u)->size(); 26 | 27 | if (exp_sz != actual) { 28 | *result_listener << "\nActual size: " << actual; 29 | return false; 30 | } 31 | return true; 32 | } 33 | 34 | class RedisParserTest : public testing::Test { 35 | protected: 36 | RedisParser::Result Parse(std::string_view str); 37 | 38 | RedisParser parser_; 39 | RespExpr::Vec args_; 40 | uint32_t consumed_; 41 | 42 | unique_ptr stash_; 43 | }; 44 | 45 | RedisParser::Result RedisParserTest::Parse(std::string_view str) { 46 | stash_.reset(new uint8_t[str.size()]); 47 | auto* ptr = stash_.get(); 48 | memcpy(ptr, str.data(), str.size()); 49 | return parser_.Parse(RedisParser::Buffer{ptr, str.size()}, &consumed_, &args_); 50 | } 51 | 52 | TEST_F(RedisParserTest, Inline) { 53 | RespExpr e{RespExpr::STRING}; 54 | ASSERT_EQ(RespExpr::STRING, e.type); 55 | 56 | const char kCmd1[] = "KEY VAL\r\n"; 57 | 58 | ASSERT_EQ(RedisParser::OK, Parse(kCmd1)); 59 | EXPECT_EQ(strlen(kCmd1), consumed_); 60 | EXPECT_THAT(args_, ElementsAre(StrArg("KEY"), StrArg("VAL"))); 61 | 62 | ASSERT_EQ(RedisParser::INPUT_PENDING, Parse("KEY")); 63 | EXPECT_EQ(3, consumed_); 64 | ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" FOO ")); 65 | EXPECT_EQ(5, consumed_); 66 | ASSERT_EQ(RedisParser::INPUT_PENDING, Parse(" BAR")); 67 | EXPECT_EQ(4, consumed_); 68 | ASSERT_EQ(RedisParser::OK, Parse(" \r\n ")); 69 | EXPECT_EQ(3, consumed_); 70 | EXPECT_THAT(args_, 
// Placeholder for inline-command escaping coverage.
// Reported as skipped rather than passed, so the missing coverage stays visible in test reports.
TEST_F(RedisParserTest, InlineEscaping) {
  GTEST_SKIP() << "TBD: to be compliant with sdssplitargs";  // TODO:
}
ASSERT_EQ(RedisParser::OK, Parse("*2\r\n$0\r\n\r\n$0\r\n\r\n")); 124 | } 125 | 126 | } // namespace dfly 127 | -------------------------------------------------------------------------------- /server/reply_builder.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | #include "server/reply_builder.h" 5 | 6 | #include 7 | #include 8 | 9 | #include "base/logging.h" 10 | 11 | using namespace std; 12 | using absl::StrAppend; 13 | 14 | namespace dfly { 15 | 16 | namespace { 17 | 18 | inline iovec constexpr IoVec(std::string_view s) { 19 | iovec r{const_cast(s.data()), s.size()}; 20 | return r; 21 | } 22 | 23 | constexpr char kCRLF[] = "\r\n"; 24 | constexpr char kErrPref[] = "-ERR "; 25 | constexpr char kSimplePref[] = "+"; 26 | 27 | } // namespace 28 | 29 | BaseSerializer::BaseSerializer(io::Sink* sink) : sink_(sink) { 30 | } 31 | 32 | void BaseSerializer::Send(const iovec* v, uint32_t len) { 33 | if (should_batch_) { 34 | // TODO: to introduce flushing when too much data is batched. 
35 | for (unsigned i = 0; i < len; ++i) { 36 | std::string_view src((char*)v[i].iov_base, v[i].iov_len); 37 | DVLOG(2) << "Appending to stream " << sink_ << " " << src; 38 | batch_.append(src.data(), src.size()); 39 | } 40 | return; 41 | } 42 | 43 | error_code ec; 44 | if (batch_.empty()) { 45 | ec = sink_->Write(v, len); 46 | } else { 47 | DVLOG(1) << "Sending batch to stream " << sink_ << "\n" << batch_; 48 | 49 | iovec tmp[len + 1]; 50 | tmp[0].iov_base = batch_.data(); 51 | tmp[0].iov_len = batch_.size(); 52 | copy(v, v + len, tmp + 1); 53 | ec = sink_->Write(tmp, len + 1); 54 | batch_.clear(); 55 | } 56 | 57 | if (ec) { 58 | ec_ = ec; 59 | } 60 | } 61 | 62 | void BaseSerializer::SendDirect(std::string_view raw) { 63 | iovec v = {IoVec(raw)}; 64 | 65 | Send(&v, 1); 66 | } 67 | 68 | void RespSerializer::SendNull() { 69 | constexpr char kNullStr[] = "$-1\r\n"; 70 | 71 | iovec v[] = {IoVec(kNullStr)}; 72 | 73 | Send(v, ABSL_ARRAYSIZE(v)); 74 | } 75 | 76 | void RespSerializer::SendSimpleString(std::string_view str) { 77 | iovec v[3] = {IoVec(kSimplePref), IoVec(str), IoVec(kCRLF)}; 78 | 79 | Send(v, ABSL_ARRAYSIZE(v)); 80 | } 81 | 82 | void RespSerializer::SendBulkString(std::string_view str) { 83 | char tmp[absl::numbers_internal::kFastToBufferSize + 3]; 84 | tmp[0] = '$'; // Format length 85 | char* next = absl::numbers_internal::FastIntToBuffer(uint32_t(str.size()), tmp + 1); 86 | *next++ = '\r'; 87 | *next++ = '\n'; 88 | 89 | std::string_view lenpref{tmp, size_t(next - tmp)}; 90 | 91 | // 3 parts: length, string and CRLF. 
92 | iovec v[3] = {IoVec(lenpref), IoVec(str), IoVec(kCRLF)}; 93 | 94 | return Send(v, ABSL_ARRAYSIZE(v)); 95 | } 96 | 97 | void MemcacheSerializer::SendStored() { 98 | SendDirect("STORED\r\n"); 99 | } 100 | 101 | void MemcacheSerializer::SendError() { 102 | SendDirect("ERROR\r\n"); 103 | } 104 | 105 | ReplyBuilder::ReplyBuilder(Protocol protocol, ::io::Sink* sink) : protocol_(protocol) { 106 | if (protocol == Protocol::REDIS) { 107 | serializer_.reset(new RespSerializer(sink)); 108 | } else { 109 | DCHECK(protocol == Protocol::MEMCACHE); 110 | serializer_.reset(new MemcacheSerializer(sink)); 111 | } 112 | } 113 | 114 | void ReplyBuilder::SendStored() { 115 | if (protocol_ == Protocol::REDIS) { 116 | as_resp()->SendSimpleString("OK"); 117 | } else { 118 | as_mc()->SendStored(); 119 | } 120 | } 121 | 122 | void ReplyBuilder::SendMCClientError(string_view str) { 123 | DCHECK(protocol_ == Protocol::MEMCACHE); 124 | 125 | iovec v[] = {IoVec("CLIENT_ERROR"), IoVec(str), IoVec(kCRLF)}; 126 | serializer_->Send(v, ABSL_ARRAYSIZE(v)); 127 | } 128 | 129 | void ReplyBuilder::EndMultilineReply() { 130 | if (protocol_ == Protocol::MEMCACHE) { 131 | serializer_->SendDirect("END\r\n"); 132 | } 133 | } 134 | 135 | void ReplyBuilder::SendError(string_view str) { 136 | DCHECK(protocol_ == Protocol::REDIS); 137 | 138 | if (str[0] == '-') { 139 | iovec v[] = {IoVec(str), IoVec(kCRLF)}; 140 | serializer_->Send(v, ABSL_ARRAYSIZE(v)); 141 | } else { 142 | iovec v[] = {IoVec(kErrPref), IoVec(str), IoVec(kCRLF)}; 143 | serializer_->Send(v, ABSL_ARRAYSIZE(v)); 144 | } 145 | } 146 | 147 | void ReplyBuilder::SendError(OpStatus status) { 148 | DCHECK(protocol_ == Protocol::REDIS); 149 | 150 | switch (status) { 151 | case OpStatus::OK: 152 | SendOk(); 153 | break; 154 | case OpStatus::KEY_NOTFOUND: 155 | SendError("no such key"); 156 | break; 157 | default: 158 | LOG(ERROR) << "Unsupported status " << status; 159 | SendError("Internal error"); 160 | break; 161 | } 162 | } 163 | 164 | void 
ReplyBuilder::SendGetReply(std::string_view key, uint32_t flags, std::string_view value) { 165 | if (protocol_ == Protocol::REDIS) { 166 | as_resp()->SendBulkString(value); 167 | } else { 168 | string first = absl::StrCat("VALUE ", key, " ", flags, " ", value.size(), "\r\n"); 169 | iovec v[] = {IoVec(first), IoVec(value), IoVec(kCRLF)}; 170 | serializer_->Send(v, ABSL_ARRAYSIZE(v)); 171 | } 172 | } 173 | 174 | void ReplyBuilder::SendGetNotFound() { 175 | if (protocol_ == Protocol::REDIS) { 176 | as_resp()->SendNull(); 177 | } 178 | } 179 | 180 | void ReplyBuilder::SendSimpleStrArr(const std::string_view* arr, uint32_t count) { 181 | CHECK(protocol_ == Protocol::REDIS); 182 | string res = absl::StrCat("*", count, kCRLF); 183 | 184 | for (size_t i = 0; i < count; ++i) { 185 | StrAppend(&res, "+", arr[i], kCRLF); 186 | } 187 | 188 | serializer_->SendDirect(res); 189 | } 190 | 191 | } // namespace dfly 192 | -------------------------------------------------------------------------------- /server/reply_builder.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | #include 5 | 6 | #include "io/sync_stream_interface.h" 7 | #include "server/dfly_protocol.h" 8 | #include "server/op_status.h" 9 | 10 | namespace dfly { 11 | 12 | class BaseSerializer { 13 | public: 14 | explicit BaseSerializer(::io::Sink* sink); 15 | 16 | std::error_code ec() const { 17 | return ec_; 18 | } 19 | 20 | void CloseConnection() { 21 | if (!ec_) 22 | ec_ = std::make_error_code(std::errc::connection_aborted); 23 | } 24 | 25 | // In order to reduce interrupt rate we allow coalescing responses together using 26 | // Batch mode. It is controlled by Connection state machine because it makes sense only 27 | // when pipelined requests are arriving. 28 | void SetBatchMode(bool batch) { 29 | should_batch_ = batch; 30 | } 31 | 32 | //! 
Sends a string as is without any formatting. raw should be encoded according to the protocol. 33 | void SendDirect(std::string_view str); 34 | 35 | void Send(const iovec* v, uint32_t len); 36 | 37 | private: 38 | ::io::Sink* sink_; 39 | std::error_code ec_; 40 | std::string batch_; 41 | 42 | bool should_batch_ = false; 43 | }; 44 | 45 | class RespSerializer : public BaseSerializer { 46 | public: 47 | RespSerializer(::io::Sink* sink) : BaseSerializer(sink) { 48 | } 49 | 50 | //! See https://redis.io/topics/protocol 51 | void SendSimpleString(std::string_view str); 52 | void SendNull(); 53 | 54 | /// aka "$6\r\nfoobar\r\n" 55 | void SendBulkString(std::string_view str); 56 | }; 57 | 58 | class MemcacheSerializer : public BaseSerializer { 59 | public: 60 | explicit MemcacheSerializer(::io::Sink* sink) : BaseSerializer(sink) { 61 | } 62 | 63 | void SendStored(); 64 | void SendError(); 65 | }; 66 | 67 | class ReplyBuilder { 68 | public: 69 | ReplyBuilder(Protocol protocol, ::io::Sink* stream); 70 | 71 | void SendStored(); 72 | 73 | void SendError(std::string_view str); 74 | void SendError(OpStatus status); 75 | 76 | void SendOk() { 77 | as_resp()->SendSimpleString("OK"); 78 | } 79 | 80 | std::error_code ec() const { 81 | return serializer_->ec(); 82 | } 83 | 84 | void SendMCClientError(std::string_view str); 85 | void EndMultilineReply(); 86 | 87 | void SendSimpleRespString(std::string_view str) { 88 | as_resp()->SendSimpleString(str); 89 | } 90 | 91 | void SendRespBlob(std::string_view str) { 92 | as_resp()->SendDirect(str); 93 | } 94 | 95 | void SendGetReply(std::string_view key, uint32_t flags, std::string_view value); 96 | void SendGetNotFound(); 97 | 98 | void SetBatchMode(bool mode) { 99 | serializer_->SetBatchMode(mode); 100 | } 101 | 102 | // Resp specific. 103 | // This one is prefixed with + and with clrf added automatically to each item.. 
104 | void SendSimpleStrArr(const std::string_view* arr, uint32_t count); 105 | 106 | private: 107 | RespSerializer* as_resp() { 108 | return static_cast(serializer_.get()); 109 | } 110 | MemcacheSerializer* as_mc() { 111 | return static_cast(serializer_.get()); 112 | } 113 | 114 | std::unique_ptr serializer_; 115 | Protocol protocol_; 116 | }; 117 | 118 | } // namespace dfly 119 | -------------------------------------------------------------------------------- /server/resp_expr.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | 5 | #include "server/resp_expr.h" 6 | 7 | #include "base/logging.h" 8 | 9 | namespace dfly { 10 | 11 | const char* RespExpr::TypeName(Type t) { 12 | switch (t) { 13 | case STRING: 14 | return "string"; 15 | case INT64: 16 | return "int"; 17 | case ARRAY: 18 | return "array"; 19 | case NIL_ARRAY: 20 | return "nil-array"; 21 | case NIL: 22 | return "nil"; 23 | case ERROR: 24 | return "error"; 25 | } 26 | ABSL_INTERNAL_UNREACHABLE; 27 | } 28 | 29 | } // namespace dfly 30 | 31 | namespace std { 32 | 33 | ostream& operator<<(ostream& os, const dfly::RespExpr& e) { 34 | using dfly::RespExpr; 35 | using dfly::ToSV; 36 | 37 | switch (e.type) { 38 | case RespExpr::INT64: 39 | os << "i" << get(e.u); 40 | break; 41 | case RespExpr::STRING: 42 | os << "'" << ToSV(e.GetBuf()) << "'"; 43 | break; 44 | case RespExpr::NIL: 45 | os << "nil"; 46 | break; 47 | case RespExpr::NIL_ARRAY: 48 | os << "[]"; 49 | break; 50 | case RespExpr::ARRAY: 51 | os << dfly::RespSpan{*get(e.u)}; 52 | break; 53 | case RespExpr::ERROR: 54 | os << "e(" << ToSV(e.GetBuf()) << ")"; 55 | break; 56 | } 57 | 58 | return os; 59 | } 60 | 61 | ostream& operator<<(ostream& os, dfly::RespSpan ras) { 62 | os << "["; 63 | if (!ras.empty()) { 64 | for (size_t i = 0; i < ras.size() - 1; ++i) { 65 | os << ras[i] << ","; 66 | } 67 | os << ras.back(); 68 
| } 69 | os << "]"; 70 | 71 | return os; 72 | } 73 | 74 | } // namespace std -------------------------------------------------------------------------------- /server/resp_expr.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | namespace dfly { 14 | 15 | class RespExpr { 16 | public: 17 | using Buffer = absl::Span; 18 | 19 | enum Type : uint8_t { STRING, ARRAY, INT64, NIL, NIL_ARRAY, ERROR }; 20 | 21 | using Vec = std::vector; 22 | Type type; 23 | bool has_support; // whether pointers in this item are supported by external storage. 24 | 25 | std::variant u; 26 | 27 | RespExpr(Type t = NIL) : type(t), has_support(false) { 28 | } 29 | 30 | static Buffer buffer(std::string* s) { 31 | return Buffer{reinterpret_cast(s->data()), s->size()}; 32 | } 33 | 34 | Buffer GetBuf() const { return std::get(u); } 35 | 36 | static const char* TypeName(Type t); 37 | }; 38 | 39 | using RespVec = RespExpr::Vec; 40 | using RespSpan = absl::Span; 41 | 42 | inline std::string_view ToSV(const absl::Span& s) { 43 | return std::string_view{reinterpret_cast(s.data()), s.size()}; 44 | } 45 | 46 | } // namespace dfly 47 | 48 | namespace std { 49 | 50 | ostream& operator<<(ostream& os, const dfly::RespExpr& e); 51 | ostream& operator<<(ostream& os, dfly::RespSpan rspan); 52 | 53 | } // namespace std -------------------------------------------------------------------------------- /server/test_utils.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | 5 | #include "server/test_utils.h" 6 | 7 | #include 8 | 9 | #include "base/logging.h" 10 | #include "util/uring/uring_pool.h" 11 | 12 | namespace dfly { 13 | 14 | using namespace testing; 15 | using namespace util; 16 | using namespace std; 17 | 18 | bool RespMatcher::MatchAndExplain(const RespExpr& e, MatchResultListener* listener) const { 19 | if (e.type != type_) { 20 | *listener << "\nWrong type: " << e.type; 21 | return false; 22 | } 23 | 24 | if (type_ == RespExpr::STRING || type_ == RespExpr::ERROR) { 25 | RespExpr::Buffer ebuf = e.GetBuf(); 26 | std::string_view actual{reinterpret_cast(ebuf.data()), ebuf.size()}; 27 | 28 | if (type_ == RespExpr::ERROR && !absl::StrContains(actual, exp_str_)) { 29 | *listener << "Actual does not contain '" << exp_str_ << "'"; 30 | return false; 31 | } 32 | if (type_ == RespExpr::STRING && exp_str_ != actual) { 33 | *listener << "\nActual string: " << actual; 34 | return false; 35 | } 36 | } else if (type_ == RespExpr::INT64) { 37 | auto actual = get(e.u); 38 | if (exp_int_ != actual) { 39 | *listener << "\nActual : " << actual << " expected: " << exp_int_; 40 | return false; 41 | } 42 | } else if (type_ == RespExpr::ARRAY) { 43 | size_t len = get(e.u)->size(); 44 | if (len != size_t(exp_int_)) { 45 | *listener << "Actual length " << len << ", expected: " << exp_int_; 46 | return false; 47 | } 48 | } 49 | 50 | return true; 51 | } 52 | 53 | void RespMatcher::DescribeTo(std::ostream* os) const { 54 | *os << "is "; 55 | switch (type_) { 56 | case RespExpr::STRING: 57 | case RespExpr::ERROR: 58 | *os << exp_str_; 59 | break; 60 | 61 | case RespExpr::INT64: 62 | *os << exp_str_; 63 | break; 64 | default: 65 | *os << "TBD"; 66 | break; 67 | } 68 | } 69 | 70 | void RespMatcher::DescribeNegationTo(std::ostream* os) const { 71 | *os << "is not "; 72 | } 73 | 74 | bool RespTypeMatcher::MatchAndExplain(const RespExpr& e, MatchResultListener* listener) const { 75 | if (e.type != type_) { 76 | *listener << "\nWrong type: " << 
RespExpr::TypeName(e.type); 77 | return false; 78 | } 79 | 80 | return true; 81 | } 82 | 83 | void RespTypeMatcher::DescribeTo(std::ostream* os) const { 84 | *os << "is " << RespExpr::TypeName(type_); 85 | } 86 | 87 | void RespTypeMatcher::DescribeNegationTo(std::ostream* os) const { 88 | *os << "is not " << RespExpr::TypeName(type_); 89 | } 90 | 91 | void PrintTo(const RespExpr::Vec& vec, std::ostream* os) { 92 | *os << "Vec: ["; 93 | if (!vec.empty()) { 94 | for (size_t i = 0; i < vec.size() - 1; ++i) { 95 | *os << vec[i] << ","; 96 | } 97 | *os << vec.back(); 98 | } 99 | *os << "]\n"; 100 | } 101 | 102 | vector ToIntArr(const RespVec& vec) { 103 | vector res; 104 | for (auto a : vec) { 105 | int64_t val; 106 | std::string_view s = ToSV(a.GetBuf()); 107 | CHECK(absl::SimpleAtoi(s, &val)) << s; 108 | res.push_back(val); 109 | } 110 | 111 | return res; 112 | } 113 | 114 | } // namespace dfly 115 | -------------------------------------------------------------------------------- /server/test_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | #include "io/io.h" 10 | #include "server/redis_parser.h" 11 | #include "util/proactor_pool.h" 12 | 13 | namespace dfly { 14 | 15 | class RespMatcher { 16 | public: 17 | RespMatcher(std::string_view val, RespExpr::Type t = RespExpr::STRING) : type_(t), exp_str_(val) { 18 | } 19 | 20 | RespMatcher(int64_t val, RespExpr::Type t = RespExpr::INT64) 21 | : type_(t), exp_int_(val) { 22 | } 23 | 24 | using is_gtest_matcher = void; 25 | 26 | bool MatchAndExplain(const RespExpr& e, testing::MatchResultListener*) const; 27 | 28 | void DescribeTo(std::ostream* os) const; 29 | 30 | void DescribeNegationTo(std::ostream* os) const; 31 | 32 | private: 33 | RespExpr::Type type_; 34 | 35 | std::string exp_str_; 36 | int64_t exp_int_; 37 | }; 38 | 39 | class RespTypeMatcher { 40 | public: 41 | RespTypeMatcher(RespExpr::Type type) : type_(type) { 42 | } 43 | 44 | using is_gtest_matcher = void; 45 | 46 | bool MatchAndExplain(const RespExpr& e, testing::MatchResultListener*) const; 47 | 48 | void DescribeTo(std::ostream* os) const; 49 | 50 | void DescribeNegationTo(std::ostream* os) const; 51 | 52 | private: 53 | RespExpr::Type type_; 54 | }; 55 | 56 | inline ::testing::PolymorphicMatcher StrArg(std::string_view str) { 57 | return ::testing::MakePolymorphicMatcher(RespMatcher(str)); 58 | } 59 | 60 | inline ::testing::PolymorphicMatcher ErrArg(std::string_view str) { 61 | return ::testing::MakePolymorphicMatcher(RespMatcher(str, RespExpr::ERROR)); 62 | } 63 | 64 | inline ::testing::PolymorphicMatcher IntArg(int64_t ival) { 65 | return ::testing::MakePolymorphicMatcher(RespMatcher(ival)); 66 | } 67 | 68 | inline ::testing::PolymorphicMatcher ArrLen(size_t len) { 69 | return ::testing::MakePolymorphicMatcher(RespMatcher(len, RespExpr::ARRAY)); 70 | } 71 | 72 | inline ::testing::PolymorphicMatcher ArgType(RespExpr::Type t) { 73 | return ::testing::MakePolymorphicMatcher(RespTypeMatcher(t)); 74 | } 75 | 76 | inline bool operator==(const 
// The table starts with 1 << kMinSizeShift buckets.
constexpr size_t kMinSizeShift = 2;
constexpr size_t kMinSize = 1 << kMinSizeShift;

// Whether a key may be stored in a bucket adjacent to its home bucket.
// This is a boolean switch, so it is typed as one.
constexpr bool kAllowDisplacements = true;

// Tells whether an offset produced by FindEmptyAround() denotes a bucket that can take a key
// directly. Offsets 0, 1 and -1 name the home bucket and its neighbours; 2 means that none of
// them is empty. With displacements disabled only the home bucket (offset 0) qualifies.
inline bool CanSetFlat(int offs) {
  if (kAllowDisplacements)
    return offs < 2;
  return offs == 0;
}
47 | } 48 | } 49 | } else if (!entry.IsEmpty()) { 50 | sdsfree((sds)entry.get()); 51 | } 52 | } 53 | DCHECK_EQ(0u, num_chain_entries_); 54 | } 55 | 56 | void StringSet::Reserve(size_t sz) { 57 | sz = std::min(sz, kMinSize); 58 | 59 | sz = absl::bit_ceil(sz); 60 | capacity_log_ = absl::bit_width(sz); 61 | entries_.reserve(sz); 62 | } 63 | 64 | size_t StringSet::SuperPtr::SetString(std::string_view str) { 65 | sds sdsptr = sdsnewlen(str.data(), str.size()); 66 | ptr = sdsptr; 67 | return sdsAllocSize(sdsptr); 68 | } 69 | 70 | bool StringSet::SuperPtr::Compare(std::string_view str) const { 71 | if (IsEmpty()) 72 | return false; 73 | 74 | sds sp = GetSds(); 75 | return str == string_view{sp, sdslen(sp)}; 76 | } 77 | 78 | bool StringSet::Add(std::string_view str) { 79 | DVLOG(1) << "Add " << absl::CHexEscape(str); 80 | 81 | uint64_t hc = HashCode(str); 82 | 83 | if (entries_.empty()) { 84 | capacity_log_ = kMinSizeShift; 85 | entries_.resize(kMinSize); 86 | auto& e = entries_[BucketId(hc)]; 87 | obj_malloc_used_ += e.SetString(str); 88 | ++size_; 89 | ++num_used_buckets_; 90 | 91 | return true; 92 | } 93 | 94 | uint32_t bucket_id = BucketId(hc); 95 | if (FindAround(str, bucket_id) < 2) 96 | return false; 97 | 98 | DCHECK_LT(bucket_id, entries_.size()); 99 | ++size_; 100 | 101 | // Try insert into flat surface first. Also handle the grow case 102 | // if utilization is too high. 
103 | for (unsigned j = 0; j < 2; ++j) { 104 | int offs = FindEmptyAround(bucket_id); 105 | if (CanSetFlat(offs)) { 106 | auto& entry = entries_[bucket_id + offs]; 107 | obj_malloc_used_ += entry.SetString(str); 108 | if (offs != 0) { 109 | entry.SetDisplaced(); 110 | } 111 | ++num_used_buckets_; 112 | return true; 113 | } 114 | 115 | if (size_ < entries_.size()) 116 | break; 117 | 118 | Grow(); 119 | bucket_id = BucketId(hc); 120 | } 121 | 122 | auto& dest = entries_[bucket_id]; 123 | DCHECK(!dest.IsEmpty()); 124 | if (dest.IsDisplaced()) { 125 | sds sptr = dest.GetSds(); 126 | uint32_t nbid = BucketId(sptr); 127 | Link(SuperPtr{sptr}, nbid); 128 | 129 | if (dest.IsSds()) { 130 | obj_malloc_used_ += dest.SetString(str); 131 | } else { 132 | LinkKey* lk = (LinkKey*)dest.get(); 133 | obj_malloc_used_ += lk->SetString(str); 134 | dest.ClearDisplaced(); 135 | } 136 | } else { 137 | LinkKey* lk = NewLink(str, dest); 138 | dest.SetLink(lk); 139 | } 140 | DCHECK(!dest.IsDisplaced()); 141 | return true; 142 | } 143 | 144 | unsigned StringSet::BucketDepth(uint32_t bid) const { 145 | SuperPtr ptr = entries_[bid]; 146 | if (ptr.IsEmpty()) { 147 | return 0; 148 | } 149 | 150 | unsigned res = 1; 151 | while (ptr.IsLink()) { 152 | LinkKey* lk = (LinkKey*)ptr.get(); 153 | ++res; 154 | ptr = lk->next; 155 | DCHECK(!ptr.IsEmpty()); 156 | } 157 | 158 | return res; 159 | } 160 | 161 | auto StringSet::NewLink(std::string_view str, SuperPtr ptr) -> LinkKey* { 162 | LinkAllocator ea(mr()); 163 | LinkKey* lk = ea.allocate(1); 164 | ea.construct(lk); 165 | obj_malloc_used_ += lk->SetString(str); 166 | lk->next = ptr; 167 | ++num_chain_entries_; 168 | 169 | return lk; 170 | } 171 | 172 | #if 0 173 | void StringSet::IterateOverBucket(uint32_t bid, const ItemCb& cb) { 174 | const Entry& e = entries_[bid]; 175 | if (e.IsEmpty()) { 176 | DCHECK(!e.next); 177 | return; 178 | } 179 | cb(e.value); 180 | 181 | const Entry* next = e.next; 182 | while (next) { 183 | cb(next->value); 184 | next = 
next->next; 185 | } 186 | } 187 | #endif 188 | 189 | inline bool cmpsds(sds sp, string_view str) { 190 | if (sdslen(sp) != str.size()) 191 | return false; 192 | return str.empty() || memcmp(sp, str.data(), str.size()) == 0; 193 | } 194 | 195 | int StringSet::FindAround(string_view str, uint32_t bid) const { 196 | SuperPtr ptr = entries_[bid]; 197 | 198 | while (ptr.IsLink()) { 199 | LinkKey* lk = (LinkKey*)ptr.get(); 200 | sds sp = (sds)lk->get(); 201 | if (cmpsds(sp, str)) 202 | return 0; 203 | ptr = lk->next; 204 | DCHECK(!ptr.IsEmpty()); 205 | } 206 | 207 | if (!ptr.IsEmpty()) { 208 | DCHECK(ptr.IsSds()); 209 | sds sp = (sds)ptr.get(); 210 | if (cmpsds(sp, str)) 211 | return 0; 212 | } 213 | 214 | if (bid && entries_[bid - 1].Compare(str)) { 215 | return -1; 216 | } 217 | 218 | if (bid + 1 < entries_.size() && entries_[bid + 1].Compare(str)) { 219 | return 1; 220 | } 221 | 222 | return 2; 223 | } 224 | 225 | void StringSet::Grow() { 226 | size_t prev_sz = entries_.size(); 227 | entries_.resize(prev_sz * 2); 228 | ++capacity_log_; 229 | 230 | for (int i = prev_sz - 1; i >= 0; --i) { 231 | SuperPtr* current = &entries_[i]; 232 | if (current->IsEmpty()) { 233 | continue; 234 | } 235 | 236 | SuperPtr* prev = nullptr; 237 | while (true) { 238 | SuperPtr next; 239 | LinkKey* lk = nullptr; 240 | sds sp; 241 | 242 | if (current->IsLink()) { 243 | lk = (LinkKey*)current->get(); 244 | sp = (sds)lk->get(); 245 | next = lk->next; 246 | } else { 247 | sp = (sds)current->get(); 248 | } 249 | 250 | uint32_t bid = BucketId(sp); 251 | if (bid != uint32_t(i)) { 252 | int offs = FindEmptyAround(bid); 253 | if (CanSetFlat(offs)) { 254 | auto& dest = entries_[bid + offs]; 255 | DCHECK(!dest.IsLink()); 256 | 257 | dest.ptr = sp; 258 | if (offs != 0) 259 | dest.SetDisplaced(); 260 | if (lk) { 261 | Free(lk); 262 | } 263 | ++num_used_buckets_; 264 | } else { 265 | Link(*current, bid); 266 | } 267 | *current = next; 268 | } else { 269 | current->ClearDisplaced(); 270 | if (lk) { 271 | 
// Attaches `ptr` to the non-empty bucket `bid`.
// `ptr` is either a plain key (sds) or an existing LinkKey node whose key is reused.
void StringSet::Link(SuperPtr ptr, uint32_t bid) {
  SuperPtr& root = entries_[bid];
  DCHECK(!root.IsEmpty());

  bool is_root_displaced = root.IsDisplaced();

  if (is_root_displaced) {
    // A displaced root holds a key whose home bucket is not `bid`.
    DCHECK_NE(bid, BucketId(root.GetSds()));
  }
  LinkKey* head;
  void* val;

  if (ptr.IsSds()) {
    if (is_root_displaced) {
      // in that case it's better to put ptr into root and move root data into its correct place.
      sds val;  // note: shadows the outer `val`; holds the key evicted from root.
      if (root.IsSds()) {
        val = (sds)root.get();
        root.ptr = ptr.get();  // get() strips the tag bits, so root is no longer displaced.
      } else {
        LinkKey* lk = (LinkKey*)root.get();
        val = (sds)lk->get();
        lk->ptr = ptr.get();
        root.ClearDisplaced();
      }
      uint32_t nbid = BucketId(val);
      DCHECK_NE(nbid, bid);

      Link(SuperPtr{val}, nbid);  // Potentially unbounded wave of updates.
      return;
    }

    // Plain key and a root that is at home: allocate a new link node for it.
    LinkAllocator ea(mr());
    head = ea.allocate(1);
    ea.construct(head);
    val = ptr.get();
    ++num_chain_entries_;
  } else {
    // Reuse the link node that was passed in, together with the key it holds.
    head = (LinkKey*)ptr.get();
    val = head->get();
  }

  if (root.IsSds()) {
    // Root holds a single key: it moves into `head`, the new key becomes the chain tail,
    // and root turns into a link to `head`.
    head->ptr = root.get();
    head->next = SuperPtr{val};
    root.SetLink(head);
    if (is_root_displaced) {
      // SetLink() overwrote the tag bits; restore the displaced mark of the root key.
      DCHECK_NE(bid, BucketId((sds)head->ptr));
      root.SetDisplaced();
    }
  } else {
    // Root is already a chain: insert `head` right after the first link.
    DCHECK(root.IsLink());
    LinkKey* chain = (LinkKey*)root.get();
    head->next = chain->next;
    head->ptr = val;
    chain->next.SetLink(head);
  }
}
could avoid checking computing HC if we explicitly mark displacement. 418 | // we have plenty-metadata to do so. 419 | uint32_t bid = BucketId((*it).HashCode()); 420 | if (bid == it.bucket_id_) { 421 | cb(*it); 422 | ++it; 423 | } 424 | } 425 | 426 | return it.entry_ ? it.bucket_id_ << (32 - capacity_log_) : 0; 427 | } 428 | 429 | bool StringSet::Erase(std::string_view val) { 430 | uint64_t hc = CompactObj::HashCode(val); 431 | uint32_t bid = BucketId(hc); 432 | 433 | Entry* current = &entries_[bid]; 434 | 435 | if (!current->IsEmpty()) { 436 | if (current->value == val) { 437 | current->Reset(); 438 | ShiftLeftIfNeeded(current); 439 | --size_; 440 | return true; 441 | } 442 | 443 | Entry* prev = current; 444 | current = current->next; 445 | while (current) { 446 | if (current->value == val) { 447 | current->Reset(); 448 | prev->next = current->next; 449 | Free(current); 450 | --size_; 451 | return true; 452 | } 453 | prev = current; 454 | current = current->next; 455 | } 456 | } 457 | 458 | auto& prev = entries_[bid - 1]; 459 | // TODO: to mark displacement. 
460 | if (bid && !prev.IsEmpty()) { 461 | if (prev.value == val) { 462 | obj_malloc_used_ -= prev.value.MallocUsed(); 463 | 464 | prev.Reset(); 465 | ShiftLeftIfNeeded(&prev); 466 | --size_; 467 | return true; 468 | } 469 | } 470 | 471 | auto& next = entries_[bid + 1]; 472 | if (bid + 1 < entries_.size()) { 473 | if (next.value == val) { 474 | obj_malloc_used_ -= next.value.MallocUsed(); 475 | next.Reset(); 476 | ShiftLeftIfNeeded(&next); 477 | --size_; 478 | return true; 479 | } 480 | } 481 | 482 | return false; 483 | } 484 | 485 | #endif 486 | 487 | void StringSet::iterator::SeekNonEmpty() { 488 | while (bucket_id_ < owner_->entries_.size()) { 489 | if (!owner_->entries_[bucket_id_].IsEmpty()) { 490 | entry_ = &owner_->entries_[bucket_id_]; 491 | return; 492 | } 493 | ++bucket_id_; 494 | } 495 | entry_ = nullptr; 496 | } 497 | 498 | void StringSet::const_iterator::SeekNonEmpty() { 499 | while (bucket_id_ < owner_->entries_.size()) { 500 | if (!owner_->entries_[bucket_id_].IsEmpty()) { 501 | entry_ = &owner_->entries_[bucket_id_]; 502 | return; 503 | } 504 | ++bucket_id_; 505 | } 506 | entry_ = nullptr; 507 | } 508 | 509 | } // namespace dfly 510 | -------------------------------------------------------------------------------- /string_set/string_set.h: -------------------------------------------------------------------------------- 1 | // Copyright 2022, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 3 | // 4 | #pragma once 5 | 6 | #include 7 | 8 | extern "C" { 9 | #include "examples/redis_dict/sds.h" 10 | } 11 | 12 | namespace dfly { 13 | 14 | // StringSet is a nice but over-optimized data-structure. Probably is not worth it in the first 15 | // place but sometimes the OCD kicks in and one can not resist. 16 | // The advantage of it over redis-dict is smaller meta-data waste. 17 | // dictEntry is 24 bytes, i.e it uses at least 32N bytes where N is the expected length. 
18 | // dict requires to allocate dictEntry per each addition in addition to the supplied key. 19 | // It also wastes space in case of a set because it stores a value pointer inside dictEntry. 20 | // To summarize: 21 | // 100% utilized dict uses N*24 + N*8 = 32N bytes not including the key space. 22 | // for 75% utilization (1/0.75 buckets): N*1.33*8 + N*24 = 35N 23 | // 24 | // This class uses 8 bytes per bucket (similarly to dictEntry*) but it used it for both 25 | // links and keys. For most cases, we remove the need for another redirection layer 26 | // and just store the key, so no "dictEntry" allocations occur. 27 | // For those cells that require chaining, the bucket is 28 | // changed in run-time to represent a linked chain. 29 | // Additional feature - in order to to reduce collisions, we insert items into 30 | // neighbour cells but only if they are empty (not chains). This way we reduce the number of 31 | // empty (unused) spaces at full utilization from 36% to ~21%. 32 | // 100% utilized table requires: N*8 + 0.2N*16 = 11.2N bytes or ~20 bytes savings. 33 | // 75% utilization: N*1.33*8 + 0.12N*16 = 13N or ~22 bytes savings per record. 34 | // TODO: to separate hash/compare functions from table logic and make it generic 35 | // with potential replacements of hset/zset data structures. 36 | // static_assert(sizeof(dictEntry) == 24); 37 | 38 | class StringSet { 39 | struct LinkKey; 40 | // we can assume that high 12 bits of user address space 41 | // can be used for tagging. At most 52 bits of address are reserved for 42 | // some configurations, and usually it's 48 bits. 43 | // https://www.kernel.org/doc/html/latest/arm64/memory.html 44 | static constexpr size_t kLinkBit = 1ULL << 52; 45 | static constexpr size_t kDisplaceBit = 1ULL << 53; 46 | static constexpr size_t kTagMask = 4095ULL << 51; // we reserve 12 high bits. 
47 | 48 | struct SuperPtr { 49 | void* ptr = nullptr; // 50 | 51 | explicit SuperPtr(void* p = nullptr) : ptr(p) { 52 | } 53 | 54 | bool IsSds() const { 55 | return (uintptr_t(ptr) & kLinkBit) == 0; 56 | } 57 | 58 | bool IsLink() const { 59 | return (uintptr_t(ptr) & kLinkBit) == kLinkBit; 60 | } 61 | 62 | bool IsEmpty() const { 63 | return ptr == nullptr; 64 | } 65 | 66 | void* get() const { 67 | return (void*)(uintptr_t(ptr) & ~kTagMask); 68 | } 69 | 70 | bool IsDisplaced() const { 71 | return (uintptr_t(ptr) & kDisplaceBit) == kDisplaceBit; 72 | } 73 | 74 | // returns usable size. 75 | size_t SetString(std::string_view str); 76 | 77 | void SetLink(LinkKey* lk) { 78 | ptr = (void*)(uintptr_t(lk) | kLinkBit); 79 | } 80 | 81 | bool Compare(std::string_view str) const; 82 | 83 | void SetDisplaced() { 84 | ptr = (void*)(uintptr_t(ptr) | kDisplaceBit); 85 | } 86 | 87 | void ClearDisplaced() { 88 | ptr = (void*)(uintptr_t(ptr) & ~kDisplaceBit); 89 | } 90 | 91 | void Reset() { 92 | ptr = nullptr; 93 | } 94 | 95 | sds GetSds() const { 96 | if (IsSds()) 97 | return (sds)get(); 98 | LinkKey* lk = (LinkKey*)get(); 99 | return (sds)lk->get(); 100 | } 101 | }; 102 | 103 | struct LinkKey : public SuperPtr { 104 | SuperPtr next; // could be LinkKey* or sds. 
105 | }; 106 | 107 | static_assert(sizeof(SuperPtr) == 8); 108 | 109 | public: 110 | class iterator; 111 | class const_iterator; 112 | // using ItemCb = std::function; 113 | 114 | StringSet(const StringSet&) = delete; 115 | 116 | explicit StringSet(std::pmr::memory_resource* mr = std::pmr::get_default_resource()); 117 | ~StringSet(); 118 | 119 | StringSet& operator=(StringSet&) = delete; 120 | 121 | void Reserve(size_t sz); 122 | 123 | bool Add(std::string_view str); 124 | 125 | bool Remove(std::string_view str); 126 | 127 | void Erase(iterator it); 128 | 129 | size_t size() const { 130 | return size_; 131 | } 132 | 133 | bool empty() const { 134 | return size_ == 0; 135 | } 136 | 137 | size_t bucket_count() const { 138 | return entries_.size(); 139 | } 140 | 141 | // those that are chained to the entries stored inline in the bucket array. 142 | size_t num_chain_entries() const { 143 | return num_chain_entries_; 144 | } 145 | 146 | size_t num_used_buckets() const { 147 | return num_used_buckets_; 148 | } 149 | 150 | bool Contains(std::string_view val) const; 151 | 152 | bool Erase(std::string_view val); 153 | 154 | iterator begin() { 155 | return iterator{this, 0}; 156 | } 157 | 158 | iterator end() { 159 | return iterator{}; 160 | } 161 | 162 | size_t obj_malloc_used() const { 163 | return obj_malloc_used_; 164 | } 165 | 166 | size_t set_malloc_used() const { 167 | return (num_chain_entries_ + entries_.capacity()) * sizeof(SuperPtr); 168 | } 169 | 170 | /// stable scanning api. has the same guarantees as redis scan command. 171 | /// we avoid doing bit-reverse by using a different function to derive a bucket id 172 | /// from hash values. By using msb part of hash we make it "stable" with respect to 173 | /// rehashes. For example, with table log size 4 (size 16), entries in bucket id 174 | /// 1110 come from hashes 1110XXXXX.... When a table grows to log size 5, 175 | /// these entries can move either to 11100 or 11101. 
So if we traversed with our cursor 176 | /// range [0000-1110], it's guaranteed that in grown table we do not need to cover again 177 | /// [00000-11100]. Similarly with shrinkage, if a table is shrinked to log size 3, 178 | /// keys from 1110 and 1111 will move to bucket 111. Again, it's guaranteed that we 179 | /// covered the range [000-111] (all keys in that case). 180 | /// Returns: next cursor or 0 if reached the end of scan. 181 | /// cursor = 0 - initiates a new scan. 182 | // uint32_t Scan(uint32_t cursor, const ItemCb& cb) const; 183 | 184 | unsigned BucketDepth(uint32_t bid) const; 185 | 186 | // void IterateOverBucket(uint32_t bid, const ItemCb& cb); 187 | 188 | class iterator { 189 | friend class StringSet; 190 | 191 | public: 192 | iterator() : owner_(nullptr), entry_(nullptr), bucket_id_(0) { 193 | } 194 | 195 | iterator& operator++(); 196 | 197 | bool operator==(const iterator& o) const { 198 | return entry_ == o.entry_; 199 | } 200 | 201 | bool operator!=(const iterator& o) const { 202 | return !(*this == o); 203 | } 204 | 205 | private: 206 | iterator(StringSet* owner, uint32_t bid) : owner_(owner), bucket_id_(bid) { 207 | SeekNonEmpty(); 208 | } 209 | 210 | void SeekNonEmpty(); 211 | 212 | StringSet* owner_ = nullptr; 213 | SuperPtr* entry_ = nullptr; 214 | uint32_t bucket_id_ = 0; 215 | }; 216 | 217 | class const_iterator { 218 | friend class StringSet; 219 | 220 | public: 221 | const_iterator() : owner_(nullptr), entry_(nullptr), bucket_id_(0) { 222 | } 223 | 224 | const_iterator& operator++(); 225 | 226 | const_iterator& operator=(iterator& it) { 227 | owner_ = it.owner_; 228 | entry_ = it.entry_; 229 | bucket_id_ = it.bucket_id_; 230 | 231 | return *this; 232 | } 233 | 234 | bool operator==(const const_iterator& o) const { 235 | return entry_ == o.entry_; 236 | } 237 | 238 | bool operator!=(const const_iterator& o) const { 239 | return !(*this == o); 240 | } 241 | 242 | private: 243 | const_iterator(const StringSet* owner, uint32_t bid) : 
owner_(owner), bucket_id_(bid) { 244 | SeekNonEmpty(); 245 | } 246 | 247 | void SeekNonEmpty(); 248 | 249 | const StringSet* owner_ = nullptr; 250 | const SuperPtr* entry_ = nullptr; 251 | uint32_t bucket_id_ = 0; 252 | }; 253 | 254 | private: 255 | friend class iterator; 256 | 257 | using LinkAllocator = std::pmr::polymorphic_allocator; 258 | 259 | std::pmr::memory_resource* mr() { 260 | return entries_.get_allocator().resource(); 261 | } 262 | 263 | uint32_t BucketId(uint64_t hash) const { 264 | return hash >> (64 - capacity_log_); 265 | } 266 | 267 | uint32_t BucketId(sds ptr) const; 268 | 269 | // Returns: 2 if no empty spaces found around the bucket. 0, -1, 1 - offset towards 270 | // an empty bucket. 271 | int FindEmptyAround(uint32_t bid) const; 272 | 273 | // returns 2 if no object was found in the vicinity. 274 | // Returns relative offset to bid: 0, -1, 1 if found. 275 | int FindAround(std::string_view str, uint32_t bid) const; 276 | 277 | void Grow(); 278 | 279 | void Link(SuperPtr ptr, uint32_t bid); 280 | /*void MoveEntry(Entry* e, uint32_t bid); 281 | 282 | void ShiftLeftIfNeeded(Entry* root) { 283 | if (root->next) { 284 | root->value = std::move(root->next->value); 285 | Entry* tmp = root->next; 286 | root->next = root->next->next; 287 | Free(tmp); 288 | } 289 | } 290 | */ 291 | void Free(LinkKey* lk) { 292 | mr()->deallocate(lk, sizeof(LinkKey), alignof(LinkKey)); 293 | --num_chain_entries_; 294 | } 295 | 296 | LinkKey* NewLink(std::string_view str, SuperPtr ptr); 297 | 298 | // The rule is - entries can be moved to vicinity as long as they are stored 299 | // "flat", i.e. not into the linked list. 
The linked list 300 | std::pmr::vector entries_; 301 | size_t obj_malloc_used_ = 0; 302 | uint32_t size_ = 0; 303 | uint32_t num_chain_entries_ = 0; 304 | uint32_t num_used_buckets_ = 0; 305 | unsigned capacity_log_ = 0; 306 | }; 307 | 308 | #if 0 309 | inline StringSet::iterator& StringSet::iterator::operator++() { 310 | if (entry_->next) { 311 | entry_ = entry_->next; 312 | } else { 313 | ++bucket_id_; 314 | SeekNonEmpty(); 315 | } 316 | 317 | return *this; 318 | } 319 | 320 | inline StringSet::const_iterator& StringSet::const_iterator::operator++() { 321 | if (entry_->next) { 322 | entry_ = entry_->next; 323 | } else { 324 | ++bucket_id_; 325 | SeekNonEmpty(); 326 | } 327 | 328 | return *this; 329 | } 330 | #endif 331 | 332 | } // namespace dfly 333 | -------------------------------------------------------------------------------- /string_set/string_set_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2022, Roman Gershman. All rights reserved. 2 | // See LICENSE for licensing terms. 
3 | // 4 | 5 | #include "string_set/string_set.h" 6 | 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | #include "base/gtest.h" 13 | #include "base/logging.h" 14 | 15 | 16 | namespace dfly { 17 | 18 | using namespace std; 19 | 20 | class StringSetTest : public ::testing::Test { 21 | protected: 22 | static void SetUpTestSuite() { 23 | } 24 | 25 | static void TearDownTestSuite() { 26 | } 27 | 28 | StringSet ss_; 29 | }; 30 | 31 | TEST_F(StringSetTest, Basic) { 32 | EXPECT_TRUE(ss_.Add("foo")); 33 | EXPECT_TRUE(ss_.Add("bar")); 34 | EXPECT_FALSE(ss_.Add("foo")); 35 | EXPECT_FALSE(ss_.Add("bar")); 36 | EXPECT_EQ(2, ss_.size()); 37 | } 38 | 39 | TEST_F(StringSetTest, Ex1) { 40 | EXPECT_TRUE(ss_.Add("AA@@@@@@@@@@@@@@")); 41 | EXPECT_TRUE(ss_.Add("AAA@@@@@@@@@@@@@")); 42 | EXPECT_TRUE(ss_.Add("AAAAAAAAA@@@@@@@")); 43 | EXPECT_TRUE(ss_.Add("AAAAAAAAAA@@@@@@")); 44 | EXPECT_TRUE(ss_.Add("AAAAAAAAAAAAAAA@")); 45 | EXPECT_TRUE(ss_.Add("BBBBBAAAAAAAAAAA")); 46 | EXPECT_TRUE(ss_.Add("BBBBBBBBAAAAAAAA")); 47 | EXPECT_TRUE(ss_.Add("CCCCCBBBBBBBBBBB")); 48 | } 49 | 50 | TEST_F(StringSetTest, Many) { 51 | double max_chain_factor = 0; 52 | for (unsigned i = 0; i < 8192; ++i) { 53 | EXPECT_TRUE(ss_.Add(absl::StrCat("xxxxxxxxxxxxxxxxx", i))); 54 | size_t sz = ss_.size(); 55 | bool should_print = (sz == ss_.bucket_count()) || (sz == ss_.bucket_count() * 0.75); 56 | if (should_print) { 57 | double chain_usage = double(ss_.num_chain_entries()) / ss_.size(); 58 | unsigned num_empty = ss_.bucket_count() - ss_.num_used_buckets(); 59 | double empty_factor = double(num_empty) / ss_.bucket_count(); 60 | 61 | LOG(INFO) << "chains: " << 100 * chain_usage << ", empty: " << 100 * empty_factor << "% at " 62 | << ss_.size(); 63 | #if 0 64 | if (ss_.size() == 15) { 65 | for (unsigned i = 0; i < ss_.bucket_count(); ++i) { 66 | LOG(INFO) << "[" << i << "]: " << ss_.BucketDepth(i); 67 | } 68 | /*ss_.IterateOverBucket(93, [this](const CompactObj& co) { 69 | LOG(INFO) << "93->" << (co.HashCode() % 
ss_.bucket_count()); 70 | });*/ 71 | } 72 | #endif 73 | } 74 | } 75 | EXPECT_EQ(8192, ss_.size()); 76 | 77 | LOG(INFO) << "max chain factor: " << 100 * max_chain_factor << "%"; 78 | /*size_t iter_len = 0; 79 | for (auto it = ss_.begin(); it != ss_.end(); ++it) { 80 | ++iter_len; 81 | } 82 | EXPECT_EQ(iter_len, 512);*/ 83 | } 84 | 85 | #if 0 86 | TEST_F(StringSetTest, IterScan) { 87 | unordered_set actual, expected; 88 | auto insert_actual = [&](const CompactObj& val) { 89 | string tmp; 90 | val.GetString(&tmp); 91 | actual.insert(tmp); 92 | }; 93 | 94 | EXPECT_EQ(0, ss_.Scan(0, insert_actual)); 95 | EXPECT_TRUE(actual.empty()); 96 | 97 | for (unsigned i = 0; i < 512; ++i) { 98 | string s = absl::StrCat("x", i); 99 | expected.insert(s); 100 | EXPECT_TRUE(ss_.Add(s)); 101 | } 102 | 103 | 104 | for (CompactObj& val : ss_) { 105 | insert_actual(val); 106 | } 107 | 108 | EXPECT_EQ(actual, expected); 109 | actual.clear(); 110 | uint32_t cursor = 0; 111 | do { 112 | cursor = ss_.Scan(cursor, insert_actual); 113 | } while (cursor); 114 | EXPECT_EQ(actual, expected); 115 | } 116 | 117 | #endif 118 | 119 | } // namespace dfly --------------------------------------------------------------------------------