├── .github └── workflows │ ├── build.yml │ └── release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── Cargo.toml ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── abi.proto ├── benches └── pubsub.rs ├── bin └── .gitkeep ├── build.rs ├── deny.toml ├── fixtures ├── ca.cert ├── ca.key ├── client.cert ├── client.conf ├── client.key ├── quic_client.conf ├── quic_server.conf ├── server.cert ├── server.conf └── server.key ├── flamegraph.svg ├── src ├── client.rs ├── config.rs ├── error.rs ├── lib.rs ├── network │ ├── frame.rs │ ├── mod.rs │ ├── multiplex │ │ ├── mod.rs │ │ ├── quic_mplex.rs │ │ └── yamux_mplex.rs │ ├── stream.rs │ ├── stream_result.rs │ └── tls.rs ├── pb │ ├── abi.rs │ └── mod.rs ├── server.rs ├── service │ ├── command_service.rs │ ├── mod.rs │ ├── topic.rs │ └── topic_service.rs └── storage │ ├── memory.rs │ ├── mod.rs │ └── sleddb.rs ├── tests └── server.rs └── tools ├── gen_cert.rs └── gen_config.rs /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | jobs: 12 | build-rust: 13 | strategy: 14 | matrix: 15 | platform: [ubuntu-latest] 16 | runs-on: ${{ matrix.platform }} 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Cache cargo registry 20 | uses: actions/cache@v1 21 | with: 22 | path: ~/.cargo/registry 23 | key: ${{ runner.os }}-cargo-registry 24 | - name: Cache cargo index 25 | uses: actions/cache@v1 26 | with: 27 | path: ~/.cargo/git 28 | key: ${{ runner.os }}-cargo-index 29 | - name: Cache cargo build 30 | uses: actions/cache@v1 31 | with: 32 | path: target 33 | key: ${{ runner.os }}-cargo-build-target 34 | - name: Install stable 35 | uses: actions-rs/toolchain@v1 36 | with: 37 | profile: minimal 38 | toolchain: stable 39 | override: true 40 | - name: Check code format 41 | run: cargo fmt -- --check 42 | - name: Check the package for errors 43 | run: 
cargo check --all 44 | - name: Lint rust sources 45 | run: cargo clippy --all-targets --all-features --tests --benches -- -D warnings 46 | - name: Run tests 47 | run: cargo test --all-features -- --test-threads=1 --nocapture 48 | - name: Generate docs 49 | run: cargo doc --all-features --no-deps 50 | - name: Deploy docs to github actions 51 | uses: peaceiris/actions-gh-pages@v3 52 | with: 53 | github_token: ${{ secrets.GITHUB_TOKEN }} 54 | publish_dir: ./target/doc 55 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 7 | 8 | jobs: 9 | build: 10 | name: Upload Release Asset 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest] 15 | steps: 16 | - name: Cache cargo registry 17 | uses: actions/cache@v1 18 | with: 19 | path: ~/.cargo/registry 20 | key: ${{ runner.os }}-cargo-registry 21 | - name: Cache cargo index 22 | uses: actions/cache@v1 23 | with: 24 | path: ~/.cargo/git 25 | key: ${{ runner.os }}-cargo-index 26 | - name: Cache cargo build 27 | uses: actions/cache@v1 28 | with: 29 | path: target 30 | key: ${{ runner.os }}-cargo-build-target 31 | - name: Checkout code 32 | uses: actions/checkout@v2 33 | with: 34 | token: ${{ secrets.GH_TOKEN }} 35 | submodules: recursive 36 | - name: Build project 37 | run: | 38 | make build-release 39 | - name: Create Release 40 | id: create_release 41 | uses: actions/create-release@v1 42 | env: 43 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 44 | with: 45 | tag_name: ${{ github.ref }} 46 | release_name: Release ${{ github.ref }} 47 | draft: false 48 | prerelease: false 49 | - name: Upload asset 50 | id: upload-kv-asset 51 | uses: actions/upload-release-asset@v1 52 | env: 53 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 54 | with: 55 | upload_url: ${{ 
steps.create_release.outputs.upload_url }} 56 | asset_path: ./target/release/kvs 57 | asset_name: kvs 58 | asset_content_type: application/octet-stream 59 | - name: Generate docs 60 | run: cargo doc --all-features --no-deps 61 | - name: Set env 62 | run: echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV 63 | - name: Deploy docs to github actions 64 | uses: peaceiris/actions-gh-pages@v3 65 | with: 66 | github_token: ${{ secrets.GITHUB_TOKEN }} 67 | publish_dir: ./target/doc/simple_kv 68 | destination_dir: ${{ env.RELEASE_VERSION }} 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 7 | Cargo.lock 8 | 9 | # These are backup files generated by rustfmt 10 | **/*.rs.bk 11 | *.out 12 | bin 13 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: false 2 | repos: 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v2.3.0 5 | hooks: 6 | - id: check-byte-order-marker 7 | - id: check-case-conflict 8 | - id: check-merge-conflict 9 | - id: check-symlinks 10 | - id: check-yaml 11 | - id: end-of-file-fixer 12 | - id: mixed-line-ending 13 | - id: trailing-whitespace 14 | - repo: https://github.com/psf/black 15 | rev: 19.3b0 16 | hooks: 17 | - id: black 18 | - repo: local 19 | hooks: 20 | - id: cargo-fmt 21 | name: cargo fmt 22 | description: Format files with rustfmt. 
23 | entry: bash -c 'cargo fmt -- --check' 24 | language: rust 25 | files: \.rs$ 26 | args: [] 27 | # - id: cargo-deny 28 | # name: cargo deny check 29 | # description: Check cargo depencencies 30 | # entry: bash -c 'cargo deny check' 31 | # language: rust 32 | # files: \.rs$ 33 | # args: [] 34 | - id: cargo-check 35 | name: cargo check 36 | description: Check the package for errors. 37 | entry: bash -c 'cargo check --all' 38 | language: rust 39 | files: \.rs$ 40 | pass_filenames: false 41 | - id: cargo-clippy 42 | name: cargo clippy 43 | description: Lint rust sources 44 | entry: bash -c 'cargo clippy --all-targets --all-features --tests --benches -- -D warnings' 45 | language: rust 46 | files: \.rs$ 47 | pass_filenames: false 48 | - id: cargo-test 49 | name: cargo test 50 | description: unit test for the project 51 | entry: bash -c 'cargo test --all-features --all' 52 | language: rust 53 | files: \.rs$ 54 | pass_filenames: false 55 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "simple-kv" 3 | version = "0.2.0" 4 | edition = "2021" 5 | 6 | [[bin]] 7 | name = "kvs" 8 | path = "src/server.rs" 9 | 10 | [[bin]] 11 | name = "kvc" 12 | path = "src/client.rs" 13 | 14 | [[bin]] 15 | name = "gen_cert" 16 | path = "tools/gen_cert.rs" 17 | 18 | [[bin]] 19 | name = "gen_config" 20 | path = "tools/gen_config.rs" 21 | 22 | [dependencies] 23 | anyhow = "1" # 错误处理 24 | async-trait = "0.1" # 异步 async trait 25 | bytes = "1" # 高效处理网络 buffer 的库 26 | certify = "0.4" # 创建 x509 cert 27 | dashmap = "5" # 并发 HashMap 28 | flate2 = "1" # gzip 压缩 29 | futures = "0.3" # 提供 Stream trait 30 | http = "0.2" # 我们使用 HTTP status code 所以引入这个类型库 31 | opentelemetry-jaeger = "0.16" # opentelemetry jaeger 支持 32 | prost = "0.9" # 处理 protobuf 的代码 33 | rustls-native-certs = "0.5" # 加载本机信任证书 34 | s2n-quic = "1" 35 | serde = { version = "1", features = 
["derive"] } # 序列化/反序列化 36 | sled = "0.34" # sled db 37 | thiserror = "1" # 错误定义和处理 38 | tokio = { version = "1", features = ["full" ] } # 异步网络库 39 | tokio-rustls = "0.22" # 处理 TLS 40 | tokio-stream = { version = "0.1", features = ["sync"] } # 处理 stream 41 | tokio-util = { version = "0.7", features = ["compat"]} # tokio 和 futures 的兼容性库 42 | toml = "0.5" # toml 支持 43 | tracing = "0.1" # 日志处理 44 | tracing-appender = "0.2" # 文件日志 45 | tracing-opentelemetry = "0.17" # opentelemetry 支持 46 | tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } # 日志处理 47 | yamux = "0.10" # yamux 多路复用支持 48 | 49 | [dev-dependencies] 50 | criterion = { version = "0.3", features = ["async_futures", "async_tokio", "html_reports"] } # benchmark 51 | rand = "0.8" # 随机数处理 52 | tempfile = "3" # 处理临时目录和临时文件 53 | 54 | [build-dependencies] 55 | prost-build = "0.9" # 编译 protobuf 56 | 57 | [[bench]] 58 | name = "pubsub" 59 | harness = false 60 | 61 | [profile.bench] 62 | debug = true 63 | 64 | [workspace] 65 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rust:latest AS chef 2 | # We only pay the installation cost once, 3 | # it will be cached from the second build onwards 4 | RUN cargo install cargo-chef 5 | 6 | WORKDIR /app 7 | 8 | FROM chef AS planner 9 | COPY . . 10 | RUN cargo chef prepare --recipe-path recipe.json 11 | 12 | FROM chef AS builder 13 | COPY --from=planner /app/recipe.json recipe.json 14 | # Build dependencies - this is the caching Docker layer! 15 | RUN cargo chef cook --release --recipe-path recipe.json 16 | 17 | # Build application 18 | COPY . . 19 | RUN cargo install --path . 20 | 21 | # We do not need the Rust toolchain to run the binary! 
22 | FROM gcr.io/distroless/cc-debian11 23 | # FROM alpine:3.14 24 | COPY --from=builder /usr/local/cargo/bin/kvs /usr/local/bin 25 | EXPOSE 9527 26 | CMD [ "/usr/local/bin/kvs" ] 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | RELEASE_TYPE ?= minor 2 | IMAGE = simple-kv 3 | TAG = $(shell git describe --abbrev=0 --tags) 4 | 5 | init: 6 | # do nothing as for now 7 | 8 | start: 9 | @cargo run --bin kvs 10 | 11 | build-proto: 12 | @BUILD_PROTO=1 cargo build 13 | 14 | build-release-local: 15 | @cargo build --release 16 | @cp ~/.target/release/kvs ./bin 17 | 18 | build-release: 19 | @cargo build --release 20 | @cp ./target/release/kvs ./bin 21 | 22 | bump-release: 23 | @cargo release $(RELEASE_TYPE) --no-dev-version --skip-publish --skip-tag 24 | 25 | show-tag: 26 | @git tag -l --format='%(contents)' $(TAG) 27 | 28 | build-docker: 29 | @docker build -t tchen/${IMAGE}:${TAG} . 
30 | @docker tag tchen/${IMAGE}:${TAG} tchen/${IMAGE}:latest 31 | 32 | push-docker: 33 | @docker push tchen/${IMAGE}:${TAG} 34 | @docker push tchen/${IMAGE}:latest 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![release](https://github.com/tyrchen/simple-kv/actions/workflows/release.yml/badge.svg)](https://github.com/tyrchen/simple-kv/actions/workflows/release.yml) 2 | 3 | [![build](https://github.com/tyrchen/simple-kv/actions/workflows/build.yml/badge.svg?branch=master)](https://github.com/tyrchen/simple-kv/actions/workflows/build.yml) 4 | 5 | # Simple KV 6 | 7 | 一个简单的 KV server 实现,配合 geektime 上我的 Rust 第一课。 8 | -------------------------------------------------------------------------------- /abi.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package abi; 4 | 5 | // 来自客户端的命令请求 6 | message CommandRequest { 7 | oneof request_data { 8 | Hget hget = 1; 9 | Hgetall hgetall = 2; 10 | Hmget hmget = 3; 11 | Hset hset = 4; 12 | Hmset hmset = 5; 13 | Hdel hdel = 6; 14 | Hmdel hmdel = 7; 15 | Hexist hexist = 8; 16 | Hmexist hmexist = 9; 17 | Subscribe subscribe = 10; 18 | Unsubscribe unsubscribe = 11; 19 | Publish publish = 12; 20 | } 21 | } 22 | 23 | // 服务器的响应 24 | message CommandResponse { 25 | // 状态码;复用 HTTP 2xx/4xx/5xx 状态码 26 | uint32 status = 1; 27 | // 如果不是 2xx,message 里包含详细的信息 28 | string message = 2; 29 | // 成功返回的 values 30 | repeated Value values = 3; 31 | // 成功返回的 kv pairs 32 | repeated Kvpair pairs = 4; 33 | } 34 | 35 | // 从 table 中获取一个 key,返回 value 36 | message Hget { 37 | string table = 1; 38 | string key = 2; 39 | } 40 | 41 | // 从 table 中获取所有的 Kvpair 42 | message Hgetall { string table = 1; } 43 | 44 | // 从 table 中获取一组 key,返回它们的 value 45 | message Hmget { 46 | string table = 1; 47 | repeated string keys = 2; 48 | } 49 | 50 | // 返回的值 51 | message Value { 52 
| oneof value { 53 | string string = 1; 54 | bytes binary = 2; 55 | int64 integer = 3; 56 | double float = 4; 57 | bool bool = 5; 58 | } 59 | } 60 | 61 | // 返回的 kvpair 62 | message Kvpair { 63 | string key = 1; 64 | Value value = 2; 65 | } 66 | 67 | // 往 table 里存一个 kvpair, 68 | // 如果 table 不存在就创建这个 table 69 | message Hset { 70 | string table = 1; 71 | Kvpair pair = 2; 72 | } 73 | 74 | // 往 table 中存一组 kvpair, 75 | // 如果 table 不存在就创建这个 table 76 | message Hmset { 77 | string table = 1; 78 | repeated Kvpair pairs = 2; 79 | } 80 | 81 | // 从 table 中删除一个 key,返回它之前的值 82 | message Hdel { 83 | string table = 1; 84 | string key = 2; 85 | } 86 | 87 | // 从 table 中删除一组 key,返回它们之前的值 88 | message Hmdel { 89 | string table = 1; 90 | repeated string keys = 2; 91 | } 92 | 93 | // 查看 key 是否存在 94 | message Hexist { 95 | string table = 1; 96 | string key = 2; 97 | } 98 | 99 | // 查看一组 key 是否存在 100 | message Hmexist { 101 | string table = 1; 102 | repeated string keys = 2; 103 | } 104 | 105 | // subscribe 到某个主题,任何发布到这个主题的数据都会被收到 106 | // 成功后,第一个返回的 CommandResponse,我们返回一个唯一的 subscription id 107 | message Subscribe { string topic = 1; } 108 | 109 | // 取消对某个主题的订阅 110 | message Unsubscribe { 111 | string topic = 1; 112 | uint32 id = 2; 113 | } 114 | 115 | // 发布数据到某个主题 116 | message Publish { 117 | string topic = 1; 118 | repeated Value data = 2; 119 | } 120 | -------------------------------------------------------------------------------- /benches/pubsub.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use criterion::{criterion_group, criterion_main, Criterion}; 3 | use futures::StreamExt; 4 | use rand::prelude::SliceRandom; 5 | use simple_kv::{ 6 | start_server_with_config, start_yamux_client_with_config, AppStream, ClientConfig, 7 | CommandRequest, ServerConfig, StorageConfig, YamuxCtrl, 8 | }; 9 | use std::time::Duration; 10 | use tokio::net::TcpStream; 11 | use tokio::runtime::Builder; 12 | use tokio::time; 13 | use 
tokio_rustls::client::TlsStream;
14 | use tracing::{info, span};
15 | use tracing_subscriber::{layer::SubscriberExt, prelude::*, EnvFilter};
16 | 
17 | /// Address the benchmark server listens on and every client connects to.
18 | const ADDR: &str = "127.0.0.1:9999";
19 | 
20 | /// Number of subscribers attached to the topic before publishing is measured.
21 | const SUBSCRIBERS: usize = 1000;
22 | 
23 | /// Spawns the KV server on [`ADDR`] as a background Tokio task.
24 | ///
25 | /// Settings come from `fixtures/server.conf`, with the listen address and the
26 | /// storage backend (`StorageConfig::MemTable`) overridden. Returns as soon as the
27 | /// task is spawned, so the server may not be accepting connections yet.
28 | async fn start_server() -> Result<()> {
29 |     let mut config: ServerConfig = toml::from_str(include_str!("../fixtures/server.conf"))?;
30 |     config.general.addr = ADDR.into();
31 |     config.storage = StorageConfig::MemTable;
32 | 
33 |     tokio::spawn(async move {
34 |         start_server_with_config(&config).await.unwrap();
35 |     });
36 | 
37 |     Ok(())
38 | }
39 | 
40 | /// Opens a client connection to [`ADDR`] using the settings in `fixtures/client.conf`.
41 | async fn connect() -> Result<YamuxCtrl<TlsStream<TcpStream>>> {
42 |     let mut config: ClientConfig = toml::from_str(include_str!("../fixtures/client.conf"))?;
43 |     config.general.addr = ADDR.into();
44 | 
45 |     Ok(start_yamux_client_with_config(&config).await?)
46 | }
47 | 
48 | /// Connects, subscribes one stream to `topic`, and spawns a task that reads and
49 | /// discards every message delivered on that subscription.
50 | async fn start_subscribers(topic: &'static str) -> Result<()> {
51 |     let mut ctrl = connect().await?;
52 |     let stream = ctrl.open_stream().await?;
53 |     info!("C(subscriber): stream opened");
54 |     let cmd = CommandRequest::new_subscribe(topic.to_string());
55 |     tokio::spawn(async move {
56 |         let mut stream = stream.execute_streaming(&cmd).await.unwrap();
57 |         // Drain until the stream ends or yields its first error.
58 |         while let Some(Ok(data)) = stream.next().await {
59 |             drop(data);
60 |         }
61 |     });
62 | 
63 |     Ok(())
64 | }
65 | 
66 | /// Publishes one randomly chosen entry of `values` to `topic` over a new connection.
67 | ///
68 | /// # Errors
69 | ///
70 | /// Fails if `values` is empty or if connecting, opening the stream, or sending the
71 | /// publish request fails.
72 | async fn start_publishers(topic: &'static str, values: &'static [&'static str]) -> Result<()> {
73 |     // The thread-local RNG is a temporary of this statement, so it is dropped
74 |     // before the first `.await` instead of living across it.
75 |     let v = *values
76 |         .choose(&mut rand::thread_rng())
77 |         .ok_or_else(|| anyhow::anyhow!("no values to publish"))?;
78 | 
79 |     let mut ctrl = connect().await?;
80 |     let mut stream = ctrl.open_stream().await?;
81 |     info!("C(publisher): stream opened");
82 | 
83 |     let cmd = CommandRequest::new_publish(topic.to_string(), vec![v.into()]);
84 |     stream.execute_unary(&cmd).await?;
85 | 
86 |     Ok(())
87 | }
88 | 
89 | /// Criterion benchmark: starts a server plus [`SUBSCRIBERS`] subscribers on one
90 | /// topic, then times a single publish (connect + publish) per iteration.
91 | fn pubsub(c: &mut Criterion) {
92 |     let tracer = opentelemetry_jaeger::new_pipeline()
93 |         .with_service_name("kv-bench")
94 |         .install_simple()
95 |         .unwrap();
96 |     let opentelemetry = tracing_opentelemetry::layer().with_tracer(tracer);
97 | 
98 |     tracing_subscriber::registry()
99 |         .with(EnvFilter::from_default_env())
100 |         .with(opentelemetry)
101 |         .init();
102 | 
103 |     let root = span!(tracing::Level::INFO, "app_start", work_units = 2);
104 |     let _enter = root.enter();
105 |     // Create the Tokio runtime used for both setup and measurement.
106 |     let runtime = Builder::new_multi_thread()
107 |         .worker_threads(4)
108 |         .thread_name("pubsub")
109 |         .enable_all()
110 |         .build()
111 |         .unwrap();
112 | 
113 |     let base_str = include_str!("../fixtures/server.conf"); // 891 bytes
114 | 
115 |     // Leak the slice so it is `'static` and can be captured by every iteration.
116 |     let values: &'static [&'static str] = Box::leak(
117 |         vec![
118 |             &base_str[..64],
119 |             &base_str[..128],
120 |             &base_str[..256],
121 |             &base_str[..512],
122 |         ]
123 |         .into_boxed_slice(),
124 |     );
125 |     let topic = "lobby";
126 | 
127 |     // Start the server and SUBSCRIBERS subscribers before measuring anything.
128 |     runtime.block_on(async {
129 |         eprint!("preparing server and subscribers");
130 |         start_server().await.unwrap();
131 |         // Give the spawned server task a moment to start listening.
132 |         time::sleep(Duration::from_millis(50)).await;
133 |         for _ in 0..SUBSCRIBERS {
134 |             start_subscribers(topic).await.unwrap();
135 |             eprint!(".");
136 |         }
137 |         eprintln!("Done!");
138 |     });
139 | 
140 |     // Run the benchmark; unwrap so a failed publish aborts the run instead of
141 |     // being timed as if it had succeeded.
142 |     c.bench_function("publishing", move |b| {
143 |         b.to_async(&runtime)
144 |             .iter(|| async { start_publishers(topic, values).await.unwrap() })
145 |     });
146 | }
147 | 
148 | criterion_group! {
149 |     name = benches;
150 |     config = Criterion::default().sample_size(10);
151 |     targets = pubsub
152 | }
153 | criterion_main!(benches);
154 | 
--------------------------------------------------------------------------------
/bin/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tyrchen/simple-kv/409fd7a97a1210cdb5ce9106072ee321e15f036d/bin/.gitkeep
--------------------------------------------------------------------------------
/build.rs:
--------------------------------------------------------------------------------
1 | use std::process::Command;
2 | 
3 | /// Regenerates the protobuf bindings in `src/pb` from `abi.proto`.
4 | ///
5 | /// Generation only runs when the build script was compiled with `BUILD_PROTO=1`
6 | /// (the Makefile's `build-proto` target sets it); otherwise the existing
7 | /// `src/pb/abi.rs` is left untouched.
8 | fn main() {
9 |     // `option_env!` reads the variable when this build script is compiled.
10 |     let build_enabled = option_env!("BUILD_PROTO") == Some("1");
11 | 
12 |     if !build_enabled {
13 |         println!("=== Skipped compiling protos ===");
14 |         return;
15 |     }
16 | 
17 |     let mut config = prost_build::Config::new();
18 |     // Generate `bytes::Bytes` for every protobuf `bytes` field (path "." = all).
19 |     config.bytes(&["."]);
20 |     config.type_attribute(".", "#[derive(PartialOrd)]");
21 |     config
22 |         .out_dir("src/pb")
23 |         .compile_protos(&["abi.proto"], &["."])
24 |         .expect("failed to compile abi.proto");
25 | 
26 |     // Arguments reach rustfmt verbatim (no shell), so a glob such as `src/*.rs`
27 |     // is never expanded; name the generated file explicitly instead.
28 |     let status = Command::new("cargo")
29 |         .args(["fmt", "--", "src/pb/abi.rs"])
30 |         .status()
31 |         .expect("failed to run cargo fmt");
32 |     if !status.success() {
33 |         println!("cargo:warning=cargo fmt exited with {}", status);
34 |     }
35 | }
36 | 
--------------------------------------------------------------------------------
/deny.toml:
--------------------------------------------------------------------------------
1 | # This template contains all of the possible sections and their default values
2 | 
3 | # Note that all fields that take a lint level have these possible values:
4 | # * deny - An error will be produced and the check will fail
5 | # * warn - A warning will be produced, but the check will not fail
6 | # * allow - No warning or error will be produced, though in some cases a note
7 | # will be
8 | 
9 | # The values provided in this template are the default values that will be used
10 | # when any section or field is not specified in your own configuration
11 | 
12 | # If 1 or more target triples (and
optionally, target_features) are specified, 13 | # only the specified targets will be checked when running `cargo deny check`. 14 | # This means, if a particular package is only ever used as a target specific 15 | # dependency, such as, for example, the `nix` crate only being used via the 16 | # `target_family = "unix"` configuration, that only having windows targets in 17 | # this list would mean the nix crate, as well as any of its exclusive 18 | # dependencies not shared by any other crates, would be ignored, as the target 19 | # list here is effectively saying which targets you are building for. 20 | targets = [ 21 | # The triple can be any string, but only the target triples built in to 22 | # rustc (as of 1.40) can be checked against actual config expressions 23 | #{ triple = "x86_64-unknown-linux-musl" }, 24 | # You can also specify which target_features you promise are enabled for a 25 | # particular target. target_features are currently not validated against 26 | # the actual valid features supported by the target architecture. 27 | #{ triple = "wasm32-unknown-unknown", features = ["atomics"] }, 28 | ] 29 | 30 | # This section is considered when running `cargo deny check advisories` 31 | # More documentation for the advisories section can be found here: 32 | # https://embarkstudios.github.io/cargo-deny/checks/advisories/cfg.html 33 | [advisories] 34 | # The path where the advisory database is cloned/fetched into 35 | db-path = "~/.cargo/advisory-db" 36 | # The url(s) of the advisory databases to use 37 | db-urls = ["https://github.com/rustsec/advisory-db"] 38 | # The lint level for security vulnerabilities 39 | vulnerability = "deny" 40 | # The lint level for unmaintained crates 41 | unmaintained = "warn" 42 | # The lint level for crates that have been yanked from their source registry 43 | yanked = "warn" 44 | # The lint level for crates with security notices. 
Note that as of 45 | # 2019-12-17 there are no security notice advisories in 46 | # https://github.com/rustsec/advisory-db 47 | notice = "warn" 48 | # A list of advisory IDs to ignore. Note that ignored advisories will still 49 | # output a note when they are encountered. 50 | ignore = [ 51 | #"RUSTSEC-0000-0000", 52 | ] 53 | # Threshold for security vulnerabilities, any vulnerability with a CVSS score 54 | # lower than the range specified will be ignored. Note that ignored advisories 55 | # will still output a note when they are encountered. 56 | # * None - CVSS Score 0.0 57 | # * Low - CVSS Score 0.1 - 3.9 58 | # * Medium - CVSS Score 4.0 - 6.9 59 | # * High - CVSS Score 7.0 - 8.9 60 | # * Critical - CVSS Score 9.0 - 10.0 61 | #severity-threshold = 62 | 63 | # This section is considered when running `cargo deny check licenses` 64 | # More documentation for the licenses section can be found here: 65 | # https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html 66 | [licenses] 67 | # The lint level for crates which do not have a detectable license 68 | unlicensed = "allow" 69 | # List of explictly allowed licenses 70 | # See https://spdx.org/licenses/ for list of possible licenses 71 | # [possible values: any SPDX 3.7 short identifier (+ optional exception)]. 72 | allow = [ 73 | "MIT", 74 | "Apache-2.0", 75 | "Apache-2.0 WITH LLVM-exception", 76 | "BSD-3-Clause", 77 | "BSD-2-Clause", 78 | "MPL-2.0", 79 | "Zlib", 80 | "CC0-1.0", 81 | "ISC", 82 | "GPL-3.0" 83 | 84 | ] 85 | # List of explictly disallowed licenses 86 | # See https://spdx.org/licenses/ for list of possible licenses 87 | # [possible values: any SPDX 3.7 short identifier (+ optional exception)]. 
88 | deny = [ 89 | #"Nokia", 90 | ] 91 | # Lint level for licenses considered copyleft 92 | copyleft = "warn" 93 | # Blanket approval or denial for OSI-approved or FSF Free/Libre licenses 94 | # * both - The license will be approved if it is both OSI-approved *AND* FSF 95 | # * either - The license will be approved if it is either OSI-approved *OR* FSF 96 | # * osi-only - The license will be approved if is OSI-approved *AND NOT* FSF 97 | # * fsf-only - The license will be approved if is FSF *AND NOT* OSI-approved 98 | # * neither - This predicate is ignored and the default lint level is used 99 | allow-osi-fsf-free = "neither" 100 | # Lint level used when no other predicates are matched 101 | # 1. License isn't in the allow or deny lists 102 | # 2. License isn't copyleft 103 | # 3. License isn't OSI/FSF, or allow-osi-fsf-free = "neither" 104 | default = "deny" 105 | # The confidence threshold for detecting a license from license text. 106 | # The higher the value, the more closely the license text must be to the 107 | # canonical license text of a valid SPDX license file. 108 | # [possible values: any between 0.0 and 1.0]. 
109 | confidence-threshold = 0.8 110 | # Allow 1 or more licenses on a per-crate basis, so that particular licenses 111 | # aren't accepted for every possible crate as with the normal allow list 112 | exceptions = [ 113 | # Each entry is the crate and version constraint, and its specific allow 114 | # list 115 | #{ allow = ["Zlib"], name = "adler32", version = "*" }, 116 | ] 117 | 118 | # Some crates don't have (easily) machine readable licensing information, 119 | # adding a clarification entry for it allows you to manually specify the 120 | # licensing information 121 | #[[licenses.clarify]] 122 | # The name of the crate the clarification applies to 123 | #name = "ring" 124 | # The optional version constraint for the crate 125 | #version = "*" 126 | # The SPDX expression for the license requirements of the crate 127 | #expression = "MIT AND ISC AND OpenSSL" 128 | # One or more files in the crate's source used as the "source of truth" for 129 | # the license expression. If the contents match, the clarification will be used 130 | # when running the license check, otherwise the clarification will be ignored 131 | # and the crate will be checked normally, which may produce warnings or errors 132 | # depending on the rest of your configuration 133 | #license-files = [ 134 | # Each entry is a crate relative path, and the (opaque) hash of its contents 135 | #{ path = "LICENSE", hash = 0xbd0eed23 } 136 | #] 137 | 138 | [licenses.private] 139 | # If true, ignores workspace crates that aren't published, or are only 140 | # published to private registries 141 | ignore = false 142 | # One or more private registries that you might publish crates to, if a crate 143 | # is only published to private registries, and ignore is true, the crate will 144 | # not have its license(s) checked 145 | registries = [ 146 | #"https://sekretz.com/registry 147 | ] 148 | 149 | # This section is considered when running `cargo deny check bans`. 
150 | # More documentation about the 'bans' section can be found here: 151 | # https://embarkstudios.github.io/cargo-deny/checks/bans/cfg.html 152 | [bans] 153 | # Lint level for when multiple versions of the same crate are detected 154 | multiple-versions = "warn" 155 | # Lint level for when a crate version requirement is `*` 156 | wildcards = "allow" 157 | # The graph highlighting used when creating dotgraphs for crates 158 | # with multiple versions 159 | # * lowest-version - The path to the lowest versioned duplicate is highlighted 160 | # * simplest-path - The path to the version with the fewest edges is highlighted 161 | # * all - Both lowest-version and simplest-path are used 162 | highlight = "all" 163 | # List of crates that are allowed. Use with care! 164 | allow = [ 165 | #{ name = "ansi_term", version = "=0.11.0" }, 166 | ] 167 | # List of crates to deny 168 | deny = [ 169 | # Each entry the name of a crate and a version range. If version is 170 | # not specified, all versions will be matched. 171 | #{ name = "ansi_term", version = "=0.11.0" }, 172 | # 173 | # Wrapper crates can optionally be specified to allow the crate when it 174 | # is a direct dependency of the otherwise banned crate 175 | #{ name = "ansi_term", version = "=0.11.0", wrappers = [] }, 176 | ] 177 | # Certain crates/versions that will be skipped when doing duplicate detection. 178 | skip = [ 179 | #{ name = "ansi_term", version = "=0.11.0" }, 180 | ] 181 | # Similarly to `skip` allows you to skip certain crates during duplicate 182 | # detection. Unlike skip, it also includes the entire tree of transitive 183 | # dependencies starting at the specified crate, up to a certain depth, which is 184 | # by default infinite 185 | skip-tree = [ 186 | #{ name = "ansi_term", version = "=0.11.0", depth = 20 }, 187 | ] 188 | 189 | # This section is considered when running `cargo deny check sources`. 
190 | # More documentation about the 'sources' section can be found here: 191 | # https://embarkstudios.github.io/cargo-deny/checks/sources/cfg.html 192 | [sources] 193 | # Lint level for what to happen when a crate from a crate registry that is not 194 | # in the allow list is encountered 195 | unknown-registry = "warn" 196 | # Lint level for what to happen when a crate from a git repository that is not 197 | # in the allow list is encountered 198 | unknown-git = "warn" 199 | # List of URLs for allowed crate registries. Defaults to the crates.io index 200 | # if not specified. If it is specified but empty, no registries are allowed. 201 | allow-registry = ["https://github.com/rust-lang/crates.io-index"] 202 | # List of URLs for allowed Git repositories 203 | allow-git = [] 204 | 205 | [sources.allow-org] 206 | # 1 or more github.com organizations to allow git sources for 207 | github = [] 208 | # 1 or more gitlab.com organizations to allow git sources for 209 | gitlab = [] 210 | # 1 or more bitbucket.org organizations to allow git sources for 211 | bitbucket = [] 212 | -------------------------------------------------------------------------------- /fixtures/ca.cert: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBeDCCASqgAwIBAgIJAK/ACf4lHrMIMAUGAytlcDAzMQswCQYDVQQGDAJDTjES 3 | MBAGA1UECgwJQWNtZSBJbmMuMRAwDgYDVQQDDAdBY21lIENBMB4XDTIxMDkyNjAx 4 | MjU1OVoXDTMxMDkyNDAxMjU1OVowMzELMAkGA1UEBgwCQ04xEjAQBgNVBAoMCUFj 5 | bWUgSW5jLjEQMA4GA1UEAwwHQWNtZSBDQTAqMAUGAytlcAMhAO6QG6Ma8p4xEZ0V 6 | 9VUutcHGutlezoR4E3geBYVojMSKo1swWTATBgNVHREEDDAKgghhY21lLmluYzAd 7 | BgNVHQ4EFgQU8kevGAonItSdc8VGWa74jwYUnHgwEgYDVR0TAQH/BAgwBgEB/wIB 8 | EDAPBgNVHQ8BAf8EBQMDBwYAMAUGAytlcANBAIzbgTAiy6SHCQxfJhlQAs1dIYgU 9 | jwyjxoeYKv/jgGoJd9fMBXsD94tEtyOwF6ph6JEJh0SwCLZrQw2OIp/H7QU= 10 | -----END CERTIFICATE----- 11 | -------------------------------------------------------------------------------- /fixtures/ca.key: 
-------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MFMCAQEwBQYDK2VwBCIEIDGUGlTZG5ZZ0NC/qfb+A96ofGsg4pPchKss2VodfQwf 3 | oSMDIQDukBujGvKeMRGdFfVVLrXBxrrZXs6EeBN4HgWFaIzEig== 4 | -----END PRIVATE KEY----- 5 | -------------------------------------------------------------------------------- /fixtures/client.cert: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBXTCCAQ+gAwIBAgIJAKEchNUbV8e4MAUGAytlcDAzMQswCQYDVQQGDAJDTjES 3 | MBAGA1UECgwJQWNtZSBJbmMuMRAwDgYDVQQDDAdBY21lIENBMB4XDTIxMDkyNjAx 4 | MjU1OVoXDTIyMDkyNjAxMjU1OVowPTELMAkGA1UEBgwCQ04xEjAQBgNVBAoMCUFj 5 | bWUgSW5jLjEaMBgGA1UEAwwRYXdlc29tZS1kZXZpY2UtaWQwKjAFBgMrZXADIQDR 6 | 1n/nkKauN1RPwmK+gXe87WIXJw/QCHUOr4GEabT3qaM2MDQwEwYDVR0lBAwwCgYI 7 | KwYBBQUHAwIwDAYDVR0TBAUwAwEBADAPBgNVHQ8BAf8EBQMDB+AAMAUGAytlcANB 8 | APpIPMooTxwQA6aYP3C0ZeEouAq5FrbXihK0d3Wt+TTi8yUoH/miz5oHxYqGan+U 9 | 0EbIBKB+blgYwEPZ6u5Xcws= 10 | -----END CERTIFICATE----- 11 | -------------------------------------------------------------------------------- /fixtures/client.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | addr = '127.0.0.1:9527' 3 | 4 | [tls] 5 | domain = 'kvserver.acme.inc' 6 | ca = """ 7 | -----BEGIN CERTIFICATE-----\r 8 | MIIBeDCCASqgAwIBAgIJAK/ACf4lHrMIMAUGAytlcDAzMQswCQYDVQQGDAJDTjES\r 9 | MBAGA1UECgwJQWNtZSBJbmMuMRAwDgYDVQQDDAdBY21lIENBMB4XDTIxMDkyNjAx\r 10 | MjU1OVoXDTMxMDkyNDAxMjU1OVowMzELMAkGA1UEBgwCQ04xEjAQBgNVBAoMCUFj\r 11 | bWUgSW5jLjEQMA4GA1UEAwwHQWNtZSBDQTAqMAUGAytlcAMhAO6QG6Ma8p4xEZ0V\r 12 | 9VUutcHGutlezoR4E3geBYVojMSKo1swWTATBgNVHREEDDAKgghhY21lLmluYzAd\r 13 | BgNVHQ4EFgQU8kevGAonItSdc8VGWa74jwYUnHgwEgYDVR0TAQH/BAgwBgEB/wIB\r 14 | EDAPBgNVHQ8BAf8EBQMDBwYAMAUGAytlcANBAIzbgTAiy6SHCQxfJhlQAs1dIYgU\r 15 | jwyjxoeYKv/jgGoJd9fMBXsD94tEtyOwF6ph6JEJh0SwCLZrQw2OIp/H7QU=\r 16 | -----END CERTIFICATE-----\r 17 | """ 18 | 
-------------------------------------------------------------------------------- /fixtures/client.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MFMCAQEwBQYDK2VwBCIEIHR35EEVaczWX7iBttg7LlOAYXzDsfRE1mOXJQjeryOR 3 | oSMDIQDR1n/nkKauN1RPwmK+gXe87WIXJw/QCHUOr4GEabT3qQ== 4 | -----END PRIVATE KEY----- 5 | -------------------------------------------------------------------------------- /fixtures/quic_client.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | addr = '127.0.0.1:9527' 3 | network = 'quic' 4 | 5 | [tls] 6 | domain = 'kvserver.acme.inc' 7 | ca = """ 8 | -----BEGIN CERTIFICATE----- 9 | MIIBeDCCAR6gAwIBAgIBKjAKBggqhkjOPQQDAjAwMRgwFgYDVQQKDA9DcmFiIHdp 10 | ZGdpdHMgU0UxFDASBgNVBAMMC01hc3RlciBDZXJ0MCIYDzE5NzUwMTAxMDAwMDAw 11 | WhgPNDA5NjAxMDEwMDAwMDBaMDAxGDAWBgNVBAoMD0NyYWIgd2lkZ2l0cyBTRTEU 12 | MBIGA1UEAwwLTWFzdGVyIENlcnQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQb 13 | bVPayLOdbKxXB4yB4Vx3Kf2Z89vsUvhmiICsjncRwBEKkP+GjTg1bSEloLvzuha9 14 | 3u78xp2/1ZaeqtVwYgJMoyUwIzAhBgNVHREEGjAYggtxbGF3cy5xbGF3c4IJbG9j 15 | YWxob3N0MAoGCCqGSM49BAMCA0gAMEUCIDrxPoQBu9G/g54f3TKYXj8bO2fdkPD1 16 | PMO712Y3e0eNAiEA9mt1NW6TDPVf+xmUA/swi8gnhlusV2Y1sB4qhDCPr9c= 17 | -----END CERTIFICATE----- 18 | """ 19 | -------------------------------------------------------------------------------- /fixtures/quic_server.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | addr = '0.0.0.0:9527' 3 | network = 'quic' 4 | 5 | [storage] 6 | type = 'MemTable' 7 | 8 | [tls] 9 | cert = """ 10 | -----BEGIN CERTIFICATE----- 11 | MIIBeDCCAR6gAwIBAgIBKjAKBggqhkjOPQQDAjAwMRgwFgYDVQQKDA9DcmFiIHdp 12 | ZGdpdHMgU0UxFDASBgNVBAMMC01hc3RlciBDZXJ0MCIYDzE5NzUwMTAxMDAwMDAw 13 | WhgPNDA5NjAxMDEwMDAwMDBaMDAxGDAWBgNVBAoMD0NyYWIgd2lkZ2l0cyBTRTEU 14 | MBIGA1UEAwwLTWFzdGVyIENlcnQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQb 15 | 
bVPayLOdbKxXB4yB4Vx3Kf2Z89vsUvhmiICsjncRwBEKkP+GjTg1bSEloLvzuha9 16 | 3u78xp2/1ZaeqtVwYgJMoyUwIzAhBgNVHREEGjAYggtxbGF3cy5xbGF3c4IJbG9j 17 | YWxob3N0MAoGCCqGSM49BAMCA0gAMEUCIDrxPoQBu9G/g54f3TKYXj8bO2fdkPD1 18 | PMO712Y3e0eNAiEA9mt1NW6TDPVf+xmUA/swi8gnhlusV2Y1sB4qhDCPr9c= 19 | -----END CERTIFICATE----- 20 | """ 21 | key = """ 22 | -----BEGIN PRIVATE KEY----- 23 | MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgtZAp9paqkz1vzQSp 24 | tw52t+ZiSKAuJRfB5JnvA6q7+CKhRANCAAQbbVPayLOdbKxXB4yB4Vx3Kf2Z89vs 25 | UvhmiICsjncRwBEKkP+GjTg1bSEloLvzuha93u78xp2/1ZaeqtVwYgJM 26 | -----END PRIVATE KEY----- 27 | """ 28 | 29 | [log] 30 | enable_log_file = false 31 | enable_jaeger = false 32 | log_level = 'info' 33 | path = '/tmp/kv-log' 34 | rotation = 'Daily' 35 | -------------------------------------------------------------------------------- /fixtures/server.cert: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBdzCCASmgAwIBAgIICpy02U2yuPowBQYDK2VwMDMxCzAJBgNVBAYMAkNOMRIw 3 | EAYDVQQKDAlBY21lIEluYy4xEDAOBgNVBAMMB0FjbWUgQ0EwHhcNMjEwOTI2MDEy 4 | NTU5WhcNMjYwOTI1MDEyNTU5WjA6MQswCQYDVQQGDAJDTjESMBAGA1UECgwJQWNt 5 | ZSBJbmMuMRcwFQYDVQQDDA5BY21lIEtWIHNlcnZlcjAqMAUGAytlcAMhAK2Z2AjF 6 | A0uiltNuCvl6EVFl6tpaS/wJYB5IdWT2IISdo1QwUjAcBgNVHREEFTATghFrdnNl 7 | cnZlci5hY21lLmluYzATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMEBTADAQEA 8 | MA8GA1UdDwEB/wQFAwMH4AAwBQYDK2VwA0EASGOmOWFPjbGhXNOmYNCa3lInbgRy 9 | iTNtB/5kElnbKkhKhRU7yQ8HTHWWkyU5WGWbOOIXEtYp+5ERUJC+mzP9Bw== 10 | -----END CERTIFICATE----- 11 | -------------------------------------------------------------------------------- /fixtures/server.conf: -------------------------------------------------------------------------------- 1 | [general] 2 | addr = '0.0.0.0:9527' 3 | 4 | [storage] 5 | type = 'MemTable' 6 | 7 | [tls] 8 | cert = """ 9 | -----BEGIN CERTIFICATE-----\r 10 | MIIBdzCCASmgAwIBAgIICpy02U2yuPowBQYDK2VwMDMxCzAJBgNVBAYMAkNOMRIw\r 11 | 
EAYDVQQKDAlBY21lIEluYy4xEDAOBgNVBAMMB0FjbWUgQ0EwHhcNMjEwOTI2MDEy\r 12 | NTU5WhcNMjYwOTI1MDEyNTU5WjA6MQswCQYDVQQGDAJDTjESMBAGA1UECgwJQWNt\r 13 | ZSBJbmMuMRcwFQYDVQQDDA5BY21lIEtWIHNlcnZlcjAqMAUGAytlcAMhAK2Z2AjF\r 14 | A0uiltNuCvl6EVFl6tpaS/wJYB5IdWT2IISdo1QwUjAcBgNVHREEFTATghFrdnNl\r 15 | cnZlci5hY21lLmluYzATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMEBTADAQEA\r 16 | MA8GA1UdDwEB/wQFAwMH4AAwBQYDK2VwA0EASGOmOWFPjbGhXNOmYNCa3lInbgRy\r 17 | iTNtB/5kElnbKkhKhRU7yQ8HTHWWkyU5WGWbOOIXEtYp+5ERUJC+mzP9Bw==\r 18 | -----END CERTIFICATE-----\r 19 | """ 20 | key = """ 21 | -----BEGIN PRIVATE KEY-----\r 22 | MFMCAQEwBQYDK2VwBCIEIPMyINaewhXwuTPUufFO2mMt/MvQMHrGDGxgdgfy/kUu\r 23 | oSMDIQCtmdgIxQNLopbTbgr5ehFRZeraWkv8CWAeSHVk9iCEnQ==\r 24 | -----END PRIVATE KEY-----\r 25 | """ 26 | 27 | [log] 28 | enable_log_file = false 29 | enable_jaeger = false 30 | log_level = 'info' 31 | path = '/tmp/kv-log' 32 | rotation = 'Daily' 33 | -------------------------------------------------------------------------------- /fixtures/server.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MFMCAQEwBQYDK2VwBCIEIPMyINaewhXwuTPUufFO2mMt/MvQMHrGDGxgdgfy/kUu 3 | oSMDIQCtmdgIxQNLopbTbgr5ehFRZeraWkv8CWAeSHVk9iCEnQ== 4 | -----END PRIVATE KEY----- 5 | -------------------------------------------------------------------------------- /src/client.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use anyhow::Result; 4 | use futures::StreamExt; 5 | use simple_kv::{ 6 | start_quic_client_with_config, start_yamux_client_with_config, AppStream, ClientConfig, 7 | CommandRequest, KvError, NetworkType, ProstClientStream, 8 | }; 9 | use tokio::{ 10 | io::{AsyncRead, AsyncWrite}, 11 | time, 12 | }; 13 | use tracing::info; 14 | 15 | #[tokio::main] 16 | async fn main() -> Result<()> { 17 | tracing_subscriber::fmt::init(); 18 | let config: ClientConfig = 
toml::from_str(include_str!("../fixtures/quic_client.conf"))?; 19 | 20 | // 打开一个 yamux ctrl 21 | match config.general.network { 22 | NetworkType::Tcp => { 23 | let ctrl = start_yamux_client_with_config(&config).await?; 24 | process(ctrl).await?; 25 | } 26 | NetworkType::Quic => { 27 | let ctrl = start_quic_client_with_config(&config).await?; 28 | process(ctrl).await?; 29 | } 30 | } 31 | 32 | println!("Done!"); 33 | 34 | Ok(()) 35 | } 36 | 37 | async fn process(mut ctrl: S) -> Result<()> 38 | where 39 | S: AppStream, 40 | T: AsyncRead + AsyncWrite + Unpin + Send + 'static, 41 | { 42 | let channel = "lobby"; 43 | start_publishing(ctrl.open_stream().await?, channel)?; 44 | 45 | let mut stream = ctrl.open_stream().await?; 46 | 47 | // 生成一个 HSET 命令 48 | let cmd = CommandRequest::new_hset("table1", "hello", "world".to_string().into()); 49 | 50 | // 发送 HSET 命令 51 | let data = stream.execute_unary(&cmd).await?; 52 | info!("Got response {:?}", data); 53 | 54 | // 生成一个 Subscribe 命令 55 | let cmd = CommandRequest::new_subscribe(channel); 56 | let mut stream = stream.execute_streaming(&cmd).await?; 57 | let id = stream.id; 58 | start_unsubscribe(ctrl.open_stream().await?, channel, id)?; 59 | 60 | while let Some(Ok(data)) = stream.next().await { 61 | println!("Got published data: {:?}", data); 62 | } 63 | 64 | Ok(()) 65 | } 66 | 67 | fn start_publishing(mut stream: ProstClientStream, name: &str) -> Result<(), KvError> 68 | where 69 | S: AsyncRead + AsyncWrite + Unpin + Send + 'static, 70 | { 71 | let cmd = CommandRequest::new_publish(name, vec![1.into(), 2.into(), "hello".into()]); 72 | tokio::spawn(async move { 73 | time::sleep(Duration::from_millis(1000)).await; 74 | let res = stream.execute_unary(&cmd).await.unwrap(); 75 | println!("Finished publishing: {:?}", res); 76 | }); 77 | 78 | Ok(()) 79 | } 80 | 81 | fn start_unsubscribe( 82 | mut stream: ProstClientStream, 83 | name: &str, 84 | id: u32, 85 | ) -> Result<(), KvError> 86 | where 87 | S: AsyncRead + AsyncWrite + Unpin + 
Send + 'static, 88 | { 89 | let cmd = CommandRequest::new_unsubscribe(name, id as _); 90 | tokio::spawn(async move { 91 | time::sleep(Duration::from_millis(2000)).await; 92 | let res = stream.execute_unary(&cmd).await.unwrap(); 93 | println!("Finished unsubscribing: {:?}", res); 94 | }); 95 | 96 | Ok(()) 97 | } 98 | -------------------------------------------------------------------------------- /src/config.rs: -------------------------------------------------------------------------------- 1 | use crate::KvError; 2 | use serde::{Deserialize, Serialize}; 3 | use std::fs; 4 | 5 | #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] 6 | pub struct ServerConfig { 7 | pub general: GeneralConfig, 8 | pub storage: StorageConfig, 9 | pub tls: ServerTlsConfig, 10 | pub log: LogConfig, 11 | } 12 | 13 | #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] 14 | pub struct ClientConfig { 15 | pub general: GeneralConfig, 16 | pub tls: ClientTlsConfig, 17 | } 18 | 19 | #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] 20 | pub struct GeneralConfig { 21 | pub addr: String, 22 | #[serde(default)] 23 | pub network: NetworkType, 24 | } 25 | 26 | #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] 27 | #[serde(rename_all = "snake_case")] 28 | pub enum NetworkType { 29 | Tcp, 30 | Quic, 31 | } 32 | 33 | impl Default for NetworkType { 34 | fn default() -> Self { 35 | NetworkType::Tcp 36 | } 37 | } 38 | 39 | #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] 40 | pub struct LogConfig { 41 | pub enable_log_file: bool, 42 | pub enable_jaeger: bool, 43 | pub log_level: String, 44 | pub path: String, 45 | pub rotation: RotationConfig, 46 | } 47 | 48 | #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] 49 | pub enum RotationConfig { 50 | Hourly, 51 | Daily, 52 | Never, 53 | } 54 | 55 | #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] 56 | #[serde(tag = "type", content = "args")] 57 | pub enum StorageConfig { 58 | MemTable, 59 | 
SledDb(String), 60 | } 61 | 62 | #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] 63 | pub struct ServerTlsConfig { 64 | pub cert: String, 65 | pub key: String, 66 | pub ca: Option, 67 | } 68 | 69 | #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] 70 | pub struct ClientTlsConfig { 71 | pub domain: String, 72 | pub identity: Option<(String, String)>, 73 | pub ca: Option, 74 | } 75 | 76 | impl ServerConfig { 77 | pub fn load(path: &str) -> Result { 78 | let config = fs::read_to_string(path)?; 79 | let config: Self = toml::from_str(&config)?; 80 | Ok(config) 81 | } 82 | } 83 | 84 | impl ClientConfig { 85 | pub fn load(path: &str) -> Result { 86 | let config = fs::read_to_string(path)?; 87 | let config: Self = toml::from_str(&config)?; 88 | Ok(config) 89 | } 90 | } 91 | 92 | #[cfg(test)] 93 | mod tests { 94 | use super::*; 95 | 96 | #[test] 97 | fn server_config_should_be_loaded() { 98 | let result: Result = 99 | toml::from_str(include_str!("../fixtures/server.conf")); 100 | assert!(result.is_ok()); 101 | } 102 | 103 | #[test] 104 | fn client_config_should_be_loaded() { 105 | let result: Result = 106 | toml::from_str(include_str!("../fixtures/client.conf")); 107 | assert!(result.is_ok()); 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug)] 4 | pub enum KvError { 5 | #[error("Not found: {0}")] 6 | NotFound(String), 7 | #[error("Frame is larger than max size")] 8 | FrameError, 9 | #[error("Command is invalid: `{0}`")] 10 | InvalidCommand(String), 11 | #[error("Cannot convert value {0} to {1}")] 12 | ConvertError(String, &'static str), 13 | #[error("Cannot process command {0} with table: {1}, key: {2}. 
Error: {}")] 14 | StorageError(&'static str, String, String, String), 15 | #[error("Certificate parse error: error to load {0} {0}")] 16 | CertifcateParseError(&'static str, &'static str), 17 | 18 | #[error("Failed to encode protobuf message")] 19 | EncodeError(#[from] prost::EncodeError), 20 | #[error("Failed to decode protobuf message")] 21 | DecodeError(#[from] prost::DecodeError), 22 | #[error("Failed to access sled db")] 23 | SledError(#[from] sled::Error), 24 | #[error("I/O error")] 25 | IoError(#[from] std::io::Error), 26 | #[error("TLS error")] 27 | TlsError(#[from] tokio_rustls::rustls::TLSError), 28 | #[error("Yamux Connection error")] 29 | YamuxConnectionError(#[from] yamux::ConnectionError), 30 | #[error("Parse config error")] 31 | ConfigError(#[from] toml::de::Error), 32 | #[error("Quic connection error")] 33 | QuicConnectionError(#[from] s2n_quic::connection::Error), 34 | 35 | #[error("Internal error: {0}")] 36 | Internal(String), 37 | } 38 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | mod config; 2 | mod error; 3 | mod network; 4 | mod pb; 5 | mod service; 6 | mod storage; 7 | 8 | use std::{net::SocketAddr, str::FromStr}; 9 | 10 | pub use config::*; 11 | pub use error::KvError; 12 | pub use network::*; 13 | pub use pb::abi::*; 14 | pub use service::*; 15 | pub use storage::*; 16 | 17 | use anyhow::Result; 18 | use s2n_quic::{client::Connect, Client, Server}; 19 | use tokio::net::{TcpListener, TcpStream}; 20 | use tokio_rustls::client; 21 | use tokio_util::compat::FuturesAsyncReadCompatExt; 22 | use tracing::{info, instrument, span}; 23 | 24 | /// 通过配置创建 KV 服务器 25 | #[instrument(skip_all)] 26 | pub async fn start_server_with_config(config: &ServerConfig) -> Result<()> { 27 | let addr = &config.general.addr; 28 | match config.general.network { 29 | NetworkType::Tcp => { 30 | let acceptor = TlsServerAcceptor::new( 31 | 
&config.tls.cert, 32 | &config.tls.key, 33 | config.tls.ca.as_deref(), 34 | )?; 35 | 36 | match &config.storage { 37 | StorageConfig::MemTable => { 38 | start_tls_server(addr, MemTable::new(), acceptor).await? 39 | } 40 | StorageConfig::SledDb(path) => { 41 | start_tls_server(addr, SledDb::new(path), acceptor).await? 42 | } 43 | }; 44 | } 45 | NetworkType::Quic => { 46 | match &config.storage { 47 | StorageConfig::MemTable => { 48 | start_quic_server(addr, MemTable::new(), &config.tls).await? 49 | } 50 | StorageConfig::SledDb(path) => { 51 | start_quic_server(addr, SledDb::new(path), &config.tls).await? 52 | } 53 | }; 54 | } 55 | } 56 | 57 | Ok(()) 58 | } 59 | 60 | /// 通过配置创建 KV 客户端 61 | #[instrument(skip_all)] 62 | pub async fn start_yamux_client_with_config( 63 | config: &ClientConfig, 64 | ) -> Result>> { 65 | let addr = &config.general.addr; 66 | let tls = &config.tls; 67 | 68 | let identity = tls.identity.as_ref().map(|(c, k)| (c.as_str(), k.as_str())); 69 | let connector = TlsClientConnector::new(&tls.domain, identity, tls.ca.as_deref())?; 70 | let stream = TcpStream::connect(addr).await?; 71 | let stream = connector.connect(stream).await?; 72 | 73 | // 打开一个 stream 74 | Ok(YamuxCtrl::new_client(stream, None)) 75 | } 76 | 77 | #[instrument(skip_all)] 78 | pub async fn start_quic_client_with_config(config: &ClientConfig) -> Result { 79 | let addr = SocketAddr::from_str(&config.general.addr)?; 80 | let tls = &config.tls; 81 | 82 | let client = Client::builder() 83 | .with_tls(tls.ca.as_ref().unwrap().as_str())? 84 | .with_io("0.0.0.0:0")? 85 | .start() 86 | .map_err(|e| anyhow::anyhow!("Failed to start client. 
Error: {}", e))?; 87 | 88 | let connect = Connect::new(addr).with_server_name("localhost"); 89 | let mut conn = client.connect(connect).await?; 90 | 91 | // ensure the connection doesn't time out with inactivity 92 | conn.keep_alive(true)?; 93 | 94 | Ok(QuicCtrl::new(conn)) 95 | } 96 | 97 | async fn start_quic_server( 98 | addr: &str, 99 | store: Store, 100 | tls_config: &ServerTlsConfig, 101 | ) -> Result<()> { 102 | let service: Service = ServiceInner::new(store).into(); 103 | let mut listener = Server::builder() 104 | .with_tls((tls_config.cert.as_str(), tls_config.key.as_str()))? 105 | .with_io(addr)? 106 | .start() 107 | .map_err(|e| anyhow::anyhow!("Failed to start server. Error: {}", e))?; 108 | 109 | info!("Start listening on {addr}"); 110 | 111 | loop { 112 | let root = span!(tracing::Level::INFO, "server_process"); 113 | let _enter = root.enter(); 114 | 115 | if let Some(mut conn) = listener.accept().await { 116 | info!("Client {} connected", conn.remote_addr()?); 117 | let svc = service.clone(); 118 | 119 | tokio::spawn(async move { 120 | while let Ok(Some(stream)) = conn.accept_bidirectional_stream().await { 121 | info!( 122 | "Accepted stream from {}", 123 | stream.connection().remote_addr()? 
124 | ); 125 | let svc1 = svc.clone(); 126 | tokio::spawn(async move { 127 | let stream = ProstServerStream::new(stream, svc1.clone()); 128 | stream.process().await.unwrap(); 129 | }); 130 | } 131 | Ok::<(), anyhow::Error>(()) 132 | }); 133 | } 134 | } 135 | } 136 | 137 | async fn start_tls_server( 138 | addr: &str, 139 | store: Store, 140 | acceptor: TlsServerAcceptor, 141 | ) -> Result<()> { 142 | let service: Service = ServiceInner::new(store).into(); 143 | let listener = TcpListener::bind(addr).await?; 144 | info!("Start listening on {}", addr); 145 | loop { 146 | let root = span!(tracing::Level::INFO, "server_process"); 147 | let _enter = root.enter(); 148 | let tls = acceptor.clone(); 149 | let (stream, addr) = listener.accept().await?; 150 | info!("Client {:?} connected", addr); 151 | 152 | let svc = service.clone(); 153 | tokio::spawn(async move { 154 | let stream = tls.accept(stream).await.unwrap(); 155 | YamuxCtrl::new_server(stream, None, move |stream| { 156 | let svc1 = svc.clone(); 157 | async move { 158 | let stream = ProstServerStream::new(stream.compat(), svc1.clone()); 159 | stream.process().await.unwrap(); 160 | Ok(()) 161 | } 162 | }); 163 | }); 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /src/network/frame.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Read, Write}; 2 | 3 | use crate::{CommandRequest, CommandResponse, KvError}; 4 | use bytes::{Buf, BufMut, BytesMut}; 5 | use flate2::{read::GzDecoder, write::GzEncoder, Compression}; 6 | use prost::Message; 7 | use tokio::io::{AsyncRead, AsyncReadExt}; 8 | use tracing::debug; 9 | 10 | /// 长度整个占用 4 个字节 11 | pub const LEN_LEN: usize = 4; 12 | /// 长度占 31 bit,所以最大的 frame 是 2G 13 | const MAX_FRAME: usize = 2 * 1024 * 1024 * 1024; 14 | /// 如果 payload 超过了 1436 字节,就做压缩 15 | const COMPRESSION_LIMIT: usize = 1436; 16 | /// 代表压缩的 bit(整个长度 4 字节的最高位) 17 | const COMPRESSION_BIT: usize = 1 << 31; 18 
| 19 | /// 处理 Frame 的 encode/decode 20 | pub trait FrameCoder 21 | where 22 | Self: Message + Sized + Default, 23 | { 24 | /// 把一个 Message encode 成一个 frame 25 | fn encode_frame(&self, buf: &mut BytesMut) -> Result<(), KvError> { 26 | let size = self.encoded_len(); 27 | 28 | if size > MAX_FRAME { 29 | return Err(KvError::FrameError); 30 | } 31 | 32 | // 我们先写入长度,如果需要压缩,再重写压缩后的长度 33 | buf.put_u32(size as _); 34 | 35 | if size > COMPRESSION_LIMIT { 36 | let mut buf1 = Vec::with_capacity(size); 37 | self.encode(&mut buf1)?; 38 | 39 | // BytesMut 支持逻辑上的 split(之后还能 unsplit) 40 | // 所以我们先把长度这 4 字节拿走,清除 41 | let payload = buf.split_off(LEN_LEN); 42 | buf.clear(); 43 | 44 | // 处理 gzip 压缩,具体可以参考 flate2 文档 45 | let mut encoder = GzEncoder::new(payload.writer(), Compression::default()); 46 | encoder.write_all(&buf1[..])?; 47 | 48 | // 压缩完成后,从 gzip encoder 中把 BytesMut 再拿回来 49 | let payload = encoder.finish()?.into_inner(); 50 | debug!("Encode a frame: size {}({})", size, payload.len()); 51 | 52 | // 写入压缩后的长度 53 | buf.put_u32((payload.len() | COMPRESSION_BIT) as _); 54 | 55 | // 把 BytesMut 再合并回来 56 | buf.unsplit(payload); 57 | 58 | Ok(()) 59 | } else { 60 | self.encode(buf)?; 61 | Ok(()) 62 | } 63 | } 64 | 65 | /// 把一个完整的 frame decode 成一个 Message 66 | fn decode_frame(buf: &mut BytesMut) -> Result { 67 | // 先取 4 字节,从中拿出长度和 compression bit 68 | let header = buf.get_u32() as usize; 69 | let (len, compressed) = decode_header(header); 70 | debug!("Got a frame: msg len {}, compressed {}", len, compressed); 71 | 72 | if compressed { 73 | // 解压缩 74 | let mut decoder = GzDecoder::new(&buf[..len]); 75 | let mut buf1 = Vec::with_capacity(len * 2); 76 | decoder.read_to_end(&mut buf1)?; 77 | buf.advance(len); 78 | 79 | // decode 成相应的消息 80 | Ok(Self::decode(&buf1[..buf1.len()])?) 
81 | } else { 82 | let msg = Self::decode(&buf[..len])?; 83 | buf.advance(len); 84 | Ok(msg) 85 | } 86 | } 87 | } 88 | 89 | impl FrameCoder for CommandRequest {} 90 | impl FrameCoder for CommandResponse {} 91 | /// Split a frame header into (body length, whether COMPRESSION_BIT is set). 92 | pub fn decode_header(header: usize) -> (usize, bool) { 93 | let len = header & !COMPRESSION_BIT; 94 | let compressed = header & COMPRESSION_BIT == COMPRESSION_BIT; 95 | (len, compressed) 96 | } 97 | 98 | /// Read one complete frame (4-byte header + body) from `stream` and append it to `buf`. 99 | pub async fn read_frame(stream: &mut S, buf: &mut BytesMut) -> Result<(), KvError> 100 | where 101 | S: AsyncRead + Unpin + Send, 102 | { 103 | let header = stream.read_u32().await? as usize; 104 | let (len, _compressed) = decode_header(header); 105 | // Reserve room for the whole frame (header + body) up front. 106 | buf.reserve(LEN_LEN + len); 107 | buf.put_u32(header as _); 108 | // Zero-fill the body instead of `unsafe { advance_mut }`: advance_mut requires the bytes to be 109 | // initialized already. Index from `start`, not LEN_LEN, so bytes already in `buf` are kept. 110 | let start = buf.len(); 111 | buf.resize(start + len, 0); 112 | stream.read_exact(&mut buf[start..]).await?; 113 | Ok(()) 114 | } 115 | 116 | #[cfg(test)] 117 | mod tests { 118 | use super::*; 119 | use crate::{utils::DummyStream, Value}; 120 | use bytes::Bytes; 121 | 122 | #[test] 123 | fn command_request_encode_decode_should_work() { 124 | let mut buf = BytesMut::new(); 125 | 126 | let cmd = CommandRequest::new_hdel("t1", "k1"); 127 | cmd.encode_frame(&mut buf).unwrap(); 128 | 129 | // the high (compression) bit is not set 130 | assert!(!is_compressed(&buf)); 131 | 132 | let cmd1 = CommandRequest::decode_frame(&mut buf).unwrap(); 133 | assert_eq!(cmd, cmd1); 134 | } 135 | 136 | #[test] 137 | fn command_response_encode_decode_should_work() { 138 | let mut buf = BytesMut::new(); 139 | 140 | let values: Vec = vec![1.into(), "hello".into(), b"data".into()]; 141 | let res: CommandResponse = values.into(); 142 | res.encode_frame(&mut buf).unwrap(); 143 | 144 | // the high (compression) bit is not set 145 | assert!(!is_compressed(&buf)); 146 | 147 | let res1 = CommandResponse::decode_frame(&mut buf).unwrap();
148 | assert_eq!(res, res1); 149 | } 150 | 151 | #[test] 152 | fn command_response_compressed_encode_decode_should_work() { 153 | let mut buf = BytesMut::new(); 154 | 155 | let value: Value = Bytes::from(vec![0u8; COMPRESSION_LIMIT + 1]).into(); 156 | let res: CommandResponse = value.into(); 157 | res.encode_frame(&mut buf).unwrap(); 158 | 159 | // the high (compression) bit is set 160 | assert!(is_compressed(&buf)); 161 | 162 | let res1 = CommandResponse::decode_frame(&mut buf).unwrap(); 163 | assert_eq!(res, res1); 164 | } 165 | 166 | #[tokio::test] 167 | async fn read_frame_should_work() { 168 | let mut buf = BytesMut::new(); 169 | let cmd = CommandRequest::new_hdel("t1", "k1"); 170 | cmd.encode_frame(&mut buf).unwrap(); 171 | let mut stream = DummyStream { buf }; 172 | 173 | let mut data = BytesMut::new(); 174 | read_frame(&mut stream, &mut data).await.unwrap(); 175 | 176 | let cmd1 = CommandRequest::decode_frame(&mut data).unwrap(); 177 | assert_eq!(cmd, cmd1); 178 | } 179 | 180 | #[tokio::test] 181 | async fn read_frame_should_append_to_existing_data() { 182 | let mut buf = BytesMut::new(); 183 | let cmd = CommandRequest::new_hdel("t1", "k1"); 184 | cmd.encode_frame(&mut buf).unwrap(); 185 | let frame = buf.clone(); 186 | let mut stream = DummyStream { buf }; 187 | 188 | // bytes already in `data` must be kept; the frame is appended after them 189 | let mut data = BytesMut::from(&b"xy"[..]); 190 | read_frame(&mut stream, &mut data).await.unwrap(); 191 | 192 | assert_eq!(&data[..2], b"xy"); 193 | assert_eq!(&data[2..], &frame[..]); 194 | } 195 | 196 | fn is_compressed(data: &[u8]) -> bool { 197 | // an empty buffer has no header byte, so it cannot carry the compression flag 198 | match data.first() { 199 | Some(&v) => v >> 7 == 1, 200 | None => false, 201 | } 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /src/network/mod.rs: -------------------------------------------------------------------------------- 1 | mod frame; 2 | mod multiplex; 3 | mod stream; 4 | mod stream_result; 5 | mod tls; 6 | 7 | pub use frame::{read_frame, FrameCoder}; 8 | pub use multiplex::{AppStream, QuicCtrl, YamuxCtrl}; 9 | pub use stream::ProstStream; 10 | pub use stream_result::StreamResult; 11 | pub use tls::{TlsClientConnector, TlsServerAcceptor}; 12 | 13 | use crate::{CommandRequest, CommandResponse, KvError, Service, Storage}; 14 | use futures::{SinkExt, StreamExt}; 15 | use tokio::io::{AsyncRead, AsyncWrite}; 16 | use tracing::{info, warn}; 17 | 18 | /// Handles reads and writes on a socket accepted by the server 19 | pub struct ProstServerStream { 20 | inner: ProstStream, 21 | service: Service, 22 | } 23 | 24 | /// Handles reads and writes on a client socket 25
| pub struct ProstClientStream { 26 | inner: ProstStream, 27 | } 28 | 29 | impl ProstServerStream 30 | where 31 | S: AsyncRead + AsyncWrite + Unpin + Send + 'static, 32 | Store: Storage, 33 | { 34 | pub fn new(stream: S, service: Service) -> Self { 35 | Self { 36 | inner: ProstStream::new(stream), 37 | service, 38 | } 39 | } 40 | 41 | pub async fn process(mut self) -> Result<(), KvError> { 42 | let stream = &mut self.inner; 43 | while let Some(Ok(cmd)) = stream.next().await { 44 | info!("Got a new command: {:?}", cmd); 45 | let mut res = self.service.execute(cmd); 46 | while let Some(data) = res.next().await { 47 | if let Err(e) = stream.send(&data).await { 48 | warn!("Failed to send response: {e:?}"); 49 | } 50 | } 51 | } 52 | // info!("Client {:?} disconnected", self.addr); 53 | Ok(()) 54 | } 55 | } 56 | 57 | impl ProstClientStream 58 | where 59 | S: AsyncRead + AsyncWrite + Unpin + Send + 'static, 60 | { 61 | pub fn new(stream: S) -> Self { 62 | Self { 63 | inner: ProstStream::new(stream), 64 | } 65 | } 66 | 67 | pub async fn execute_unary( 68 | &mut self, 69 | cmd: &CommandRequest, 70 | ) -> Result { 71 | let stream = &mut self.inner; 72 | stream.send(cmd).await?; 73 | 74 | match stream.next().await { 75 | Some(v) => v, 76 | None => Err(KvError::Internal("Didn't get any response".into())), 77 | } 78 | } 79 | 80 | pub async fn execute_streaming(self, cmd: &CommandRequest) -> Result { 81 | let mut stream = self.inner; 82 | 83 | stream.send(cmd).await?; 84 | stream.close().await?; 85 | 86 | StreamResult::new(stream).await 87 | } 88 | } 89 | 90 | #[cfg(test)] 91 | pub mod utils { 92 | use anyhow::Result; 93 | use bytes::{BufMut, BytesMut}; 94 | use std::{cmp::min, task::Poll}; 95 | use tokio::io::{AsyncRead, AsyncWrite}; 96 | 97 | #[derive(Default)] 98 | pub struct DummyStream { 99 | pub buf: BytesMut, 100 | } 101 | 102 | impl AsyncRead for DummyStream { 103 | fn poll_read( 104 | self: std::pin::Pin<&mut Self>, 105 | _cx: &mut std::task::Context<'_>, 106 | buf: 
&mut tokio::io::ReadBuf<'_>, 107 | ) -> Poll> { 108 | let this = self.get_mut(); 109 | let len = min(buf.remaining(), this.buf.len()); // remaining(), not capacity(): put_slice panics if the slice exceeds the unfilled space 110 | let data = this.buf.split_to(len); 111 | buf.put_slice(&data); 112 | Poll::Ready(Ok(())) 113 | } 114 | } 115 | 116 | impl AsyncWrite for DummyStream { 117 | fn poll_write( 118 | self: std::pin::Pin<&mut Self>, 119 | _cx: &mut std::task::Context<'_>, 120 | buf: &[u8], 121 | ) -> Poll> { 122 | self.get_mut().buf.put_slice(buf); 123 | Poll::Ready(Ok(buf.len())) 124 | } 125 | 126 | fn poll_flush( 127 | self: std::pin::Pin<&mut Self>, 128 | _cx: &mut std::task::Context<'_>, 129 | ) -> Poll> { 130 | Poll::Ready(Ok(())) 131 | } 132 | 133 | fn poll_shutdown( 134 | self: std::pin::Pin<&mut Self>, 135 | _cx: &mut std::task::Context<'_>, 136 | ) -> Poll> { 137 | Poll::Ready(Ok(())) 138 | } 139 | } 140 | } 141 | 142 | #[cfg(test)] 143 | mod tests { 144 | use std::net::SocketAddr; 145 | 146 | use super::*; 147 | use crate::{assert_res_ok, MemTable, ServiceInner, Value}; 148 | use anyhow::Result; 149 | use bytes::Bytes; 150 | use tokio::net::{TcpListener, TcpStream}; 151 | 152 | #[tokio::test] 153 | async fn client_server_basic_communication_should_work() -> anyhow::Result<()> { 154 | let addr = start_server().await?; 155 | 156 | let stream = TcpStream::connect(addr).await?; 157 | let mut client = ProstClientStream::new(stream); 158 | 159 | // send an HSET and wait for the response 160 | 161 | let cmd = CommandRequest::new_hset("t1", "k1", "v1".into()); 162 | let res = client.execute_unary(&cmd).await.unwrap(); 163 | 164 | // the first HSET should get None (the default Value) back 165 | assert_res_ok(&res, &[Value::default()], &[]); 166 | 167 | // now send an HGET for the same key 168 | let cmd = CommandRequest::new_hget("t1", "k1"); 169 | let res = client.execute_unary(&cmd).await?; 170 | 171 | // the server should return the value stored by the HSET 172 | assert_res_ok(&res, &["v1".into()], &[]); 173 | 174 | // send a SUBSCRIBE 175 | let cmd = CommandRequest::new_subscribe("chat"); 176 | let res = client.execute_streaming(&cmd).await?; 177 | let id = res.id; 178 |
assert!(id > 0); 179 | 180 | Ok(()) 181 | } 182 | 183 | #[tokio::test] 184 | async fn client_server_compression_should_work() -> anyhow::Result<()> { 185 | let addr = start_server().await?; 186 | 187 | let stream = TcpStream::connect(addr).await?; 188 | let mut client = ProstClientStream::new(stream); 189 | 190 | let v: Value = Bytes::from(vec![0u8; 16384]).into(); 191 | let cmd = CommandRequest::new_hset("t2", "k2", v.clone()); 192 | let res = client.execute_unary(&cmd).await?; 193 | 194 | assert_res_ok(&res, &[Value::default()], &[]); 195 | 196 | let cmd = CommandRequest::new_hget("t2", "k2"); 197 | let res = client.execute_unary(&cmd).await?; 198 | 199 | assert_res_ok(&res, &[v], &[]); 200 | 201 | Ok(()) 202 | } 203 | 204 | async fn start_server() -> Result { 205 | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 206 | let addr = listener.local_addr().unwrap(); 207 | 208 | tokio::spawn(async move { 209 | loop { 210 | let (stream, _) = listener.accept().await.unwrap(); 211 | let service: Service = ServiceInner::new(MemTable::new()).into(); 212 | let server = ProstServerStream::new(stream, service); 213 | tokio::spawn(server.process()); 214 | } 215 | }); 216 | 217 | Ok(addr) 218 | } 219 | } 220 | -------------------------------------------------------------------------------- /src/network/multiplex/mod.rs: -------------------------------------------------------------------------------- 1 | mod quic_mplex; 2 | mod yamux_mplex; 3 | 4 | pub use quic_mplex::*; 5 | pub use yamux_mplex::*; 6 | 7 | use crate::{KvError, ProstClientStream}; 8 | use async_trait::async_trait; 9 | 10 | #[async_trait] 11 | pub trait AppStream { 12 | type InnerStream; 13 | async fn open_stream(&mut self) -> Result, KvError>; 14 | } 15 | -------------------------------------------------------------------------------- /src/network/multiplex/quic_mplex.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use 
s2n_quic::{stream::BidirectionalStream, Connection}; 3 | use tracing::instrument; 4 | 5 | use crate::{AppStream, KvError, ProstClientStream}; 6 | 7 | pub struct QuicCtrl { 8 | ctrl: Connection, 9 | } 10 | 11 | impl QuicCtrl { 12 | pub fn new(conn: Connection) -> Self { 13 | Self { ctrl: conn } 14 | } 15 | } 16 | 17 | #[async_trait] 18 | impl AppStream for QuicCtrl { 19 | type InnerStream = BidirectionalStream; 20 | 21 | #[instrument(skip_all)] 22 | async fn open_stream(&mut self) -> Result, KvError> { 23 | let stream = self.ctrl.open_bidirectional_stream().await?; 24 | Ok(ProstClientStream::new(stream)) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/network/multiplex/yamux_mplex.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use futures::{future, Future, TryStreamExt}; 3 | use std::marker::PhantomData; 4 | use tokio::io::{AsyncRead, AsyncWrite}; 5 | use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt}; 6 | use tracing::instrument; 7 | use yamux::{Config, Connection, ConnectionError, Control, Mode, WindowUpdateMode}; 8 | 9 | use crate::{KvError, ProstClientStream}; 10 | 11 | use super::AppStream; 12 | 13 | /// Yamux control wrapper 14 | pub struct YamuxCtrl { 15 | /// yamux control handle, used to open new streams 16 | ctrl: Control, 17 | _conn: PhantomData, 18 | } 19 | 20 | impl YamuxCtrl 21 | where 22 | S: AsyncRead + AsyncWrite + Unpin + Send + 'static, 23 | { 24 | /// Create a yamux client; inbound streams are accepted and dropped 25 | pub fn new_client(stream: S, config: Option) -> Self { 26 | Self::new(stream, config, true, |_stream| future::ready(Ok(()))) 27 | } 28 | 29 | /// Create a yamux server; `f` handles every inbound stream 30 | pub fn new_server(stream: S, config: Option, f: F) -> Self 31 | where 32 | F: FnMut(yamux::Stream) -> Fut, 33 | F: Send + 'static, 34 | Fut: Future> + Send + 'static, 35 | { 36 | Self::new(stream, config, false, f) 37 | } 38 | 39 | #[instrument(name =
"yamux_ctrl_new", skip_all)] 40 | // Build a YamuxCtrl in client or server mode; `f` is run for each inbound stream 41 | fn new(stream: S, config: Option, is_client: bool, f: F) -> Self 42 | where 43 | F: FnMut(yamux::Stream) -> Fut, 44 | F: Send + 'static, 45 | Fut: Future> + Send + 'static, 46 | { 47 | let mode = if is_client { 48 | Mode::Client 49 | } else { 50 | Mode::Server 51 | }; 52 | 53 | // use the caller's config (or the default) and send window updates as data is read 54 | let mut config = config.unwrap_or_default(); 55 | config.set_window_update_mode(WindowUpdateMode::OnRead); 56 | 57 | // create the connection; yamux uses the futures io traits, so compat() adapts the tokio stream 58 | let conn = Connection::new(stream.compat(), config, mode); 59 | 60 | // grab a control handle for opening new streams 61 | let ctrl = conn.control(); 62 | 63 | // drive the connection in the background, running f on every inbound stream concurrently 64 | tokio::spawn(yamux::into_stream(conn).try_for_each_concurrent(None, f)); 65 | 66 | Self { 67 | ctrl, 68 | _conn: PhantomData, 69 | } 70 | } 71 | } 72 | 73 | #[async_trait] 74 | impl AppStream for YamuxCtrl 75 | where 76 | S: AsyncRead + AsyncWrite + Unpin + Send + 'static, 77 | { 78 | type InnerStream = Compat; 79 | 80 | #[instrument(skip_all)] 81 | async fn open_stream(&mut self) -> Result, KvError> { 82 | let stream = self.ctrl.open_stream().await?; 83 | Ok(ProstClientStream::new(stream.compat())) 84 | } 85 | } 86 | 87 | #[cfg(test)] 88 | mod tests { 89 | use std::net::SocketAddr; 90 | 91 | use super::*; 92 | use crate::{ 93 | assert_res_ok, 94 | network::tls::tls_utils::{tls_acceptor, tls_connector}, 95 | utils::DummyStream, 96 | CommandRequest, KvError, MemTable, ProstServerStream, Service, ServiceInner, Storage, 97 | TlsServerAcceptor, 98 | }; 99 | use anyhow::Result; 100 | use tokio::net::{TcpListener, TcpStream}; 101 | use tokio_rustls::server; 102 | use tracing::warn; 103 | 104 | pub async fn start_server_with( 105 | addr: &str, 106 | tls: TlsServerAcceptor, 107 | store: Store, 108 | f: impl Fn(server::TlsStream, Service) + Send + Sync + 'static, 109 | ) -> Result 110 | where 111 | Store: Storage, 112 | Service: From>, 113 | { 114 | let listener =
TcpListener::bind(addr).await.unwrap(); 115 | let addr = listener.local_addr().unwrap(); 116 | let service: Service = ServiceInner::new(store).into(); 117 | 118 | tokio::spawn(async move { 119 | loop { 120 | match listener.accept().await { 121 | Ok((stream, _addr)) => match tls.accept(stream).await { 122 | Ok(stream) => f(stream, service.clone()), 123 | Err(e) => warn!("Failed to process TLS: {:?}", e), 124 | }, 125 | Err(e) => warn!("Failed to process TCP: {:?}", e), 126 | } 127 | } 128 | }); 129 | 130 | Ok(addr) 131 | } 132 | 133 | /// 创建 ymaux server 134 | pub async fn start_yamux_server( 135 | addr: &str, 136 | tls: TlsServerAcceptor, 137 | store: Store, 138 | ) -> Result 139 | where 140 | Store: Storage, 141 | Service: From>, 142 | { 143 | let f = |stream, service: Service| { 144 | YamuxCtrl::new_server(stream, None, move |s| { 145 | let svc = service.clone(); 146 | async move { 147 | let stream = ProstServerStream::new(s.compat(), svc); 148 | stream.process().await.unwrap(); 149 | Ok(()) 150 | } 151 | }); 152 | }; 153 | start_server_with(addr, tls, store, f).await 154 | } 155 | 156 | #[tokio::test] 157 | async fn yamux_ctrl_creation_should_work() -> Result<()> { 158 | let s = DummyStream::default(); 159 | let mut ctrl = YamuxCtrl::new_client(s, None); 160 | let stream = ctrl.open_stream().await; 161 | 162 | assert!(stream.is_ok()); 163 | Ok(()) 164 | } 165 | 166 | #[tokio::test] 167 | async fn yamux_ctrl_client_server_should_work() -> Result<()> { 168 | // 创建使用了 TLS 的 yamux server 169 | let acceptor = tls_acceptor(false)?; 170 | let addr = start_yamux_server("127.0.0.1:0", acceptor, MemTable::new()).await?; 171 | 172 | let connector = tls_connector(false)?; 173 | let stream = TcpStream::connect(addr).await?; 174 | let stream = connector.connect(stream).await?; 175 | // 创建使用了 TLS 的 yamux client 176 | let mut ctrl = YamuxCtrl::new_client(stream, None); 177 | 178 | // 从 client ctrl 中打开一个新的 yamux stream 179 | let mut stream = ctrl.open_stream().await?; 180 | 181 
| let cmd = CommandRequest::new_hset("t1", "k1", "v1".into()); 182 | stream.execute_unary(&cmd).await.unwrap(); 183 | 184 | let cmd = CommandRequest::new_hget("t1", "k1"); 185 | let res = stream.execute_unary(&cmd).await.unwrap(); 186 | assert_res_ok(&res, &["v1".into()], &[]); 187 | 188 | Ok(()) 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /src/network/stream.rs: -------------------------------------------------------------------------------- 1 | use bytes::BytesMut; 2 | use futures::{ready, FutureExt, Sink, Stream}; 3 | use std::{ 4 | marker::PhantomData, 5 | pin::Pin, 6 | task::{Context, Poll}, 7 | }; 8 | use tokio::io::{AsyncRead, AsyncWrite}; 9 | 10 | use crate::{read_frame, FrameCoder, KvError}; 11 | 12 | /// Stream + Sink adapter that reads and writes length-prefixed prost frames of the KV server 13 | pub struct ProstStream { 14 | // inner stream 15 | stream: S, 16 | // write buffer 17 | wbuf: BytesMut, 18 | // how many bytes of wbuf have been written so far 19 | written: usize, 20 | // read buffer 21 | rbuf: BytesMut, 22 | 23 | // type markers for the decoded (In) and encoded (Out) message types 24 | _in: PhantomData, 25 | _out: PhantomData, 26 | } 27 | 28 | impl Stream for ProstStream 29 | where 30 | S: AsyncRead + AsyncWrite + Unpin + Send, 31 | In: Unpin + Send + FrameCoder, 32 | Out: Unpin + Send, 33 | { 34 | /// next() yields each decoded frame as a Result 35 | type Item = Result; 36 | 37 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 38 | // rbuf must be empty once the previous call returned. NOTE(review): decode_frame appears to leave bytes behind when decoding fails, which would trip this assert on the next poll. 39 | assert!(self.rbuf.is_empty()); 40 | 41 | // move rbuf's contents into `rest` so the future below does not borrow self 42 | let mut rest = self.rbuf.split_off(0); 43 | 44 | // read one frame. NOTE(review): this future is recreated on every poll and dropped on Pending, so any bytes it already consumed are lost and the stream desyncs; TODO keep the in-flight read across polls. 45 | let fut = read_frame(&mut self.stream, &mut rest); 46 | ready!(Box::pin(fut).poll_unpin(cx))?; 47 | 48 | // a whole frame was read; merge the buffer back 49 | self.rbuf.unsplit(rest); 50 | 51 | // decode the frame, consuming it from rbuf 52 | Poll::Ready(Some(In::decode_frame(&mut self.rbuf))) 53 | } 54 | } 55 | 56 | /// send() encodes an Out into the write buffer and flushes it to the stream 57 | impl Sink<&Out> for ProstStream 58 | where 59 | S: AsyncRead + AsyncWrite + Unpin, 60 | In: Unpin + Send, 61 | Out: Unpin + Send
+ FrameCoder, 62 | { 63 | /// 如果发送出错,会返回 KvError 64 | type Error = KvError; 65 | 66 | fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { 67 | Poll::Ready(Ok(())) 68 | } 69 | 70 | fn start_send(self: Pin<&mut Self>, item: &Out) -> Result<(), Self::Error> { 71 | let this = self.get_mut(); 72 | item.encode_frame(&mut this.wbuf)?; 73 | 74 | Ok(()) 75 | } 76 | 77 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 78 | let this = self.get_mut(); 79 | 80 | // 循环写入 stream 中 81 | while this.written != this.wbuf.len() { 82 | let n = ready!(Pin::new(&mut this.stream).poll_write(cx, &this.wbuf[this.written..]))?; 83 | this.written += n; 84 | } 85 | 86 | // 清除 wbuf 87 | this.wbuf.clear(); 88 | this.written = 0; 89 | 90 | // 调用 stream 的 pull_flush 确保写入 91 | ready!(Pin::new(&mut this.stream).poll_flush(cx)?); 92 | Poll::Ready(Ok(())) 93 | } 94 | 95 | fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 96 | // 调用 stream 的 pull_flush 确保写入 97 | ready!(self.as_mut().poll_flush(cx))?; 98 | 99 | // 调用 stream 的 pull_shutdown 确保 stream 关闭 100 | ready!(Pin::new(&mut self.stream).poll_shutdown(cx))?; 101 | Poll::Ready(Ok(())) 102 | } 103 | } 104 | 105 | // 一般来说,如果我们的 Stream 是 Unpin,最好实现一下 106 | // Unpin 不像 Send/Sync 会自动实现 107 | impl Unpin for ProstStream where S: Unpin {} 108 | 109 | impl ProstStream 110 | where 111 | S: AsyncRead + AsyncWrite + Send + Unpin, 112 | { 113 | /// 创建一个 ProstStream 114 | pub fn new(stream: S) -> Self { 115 | Self { 116 | stream, 117 | written: 0, 118 | wbuf: BytesMut::new(), 119 | rbuf: BytesMut::new(), 120 | _in: PhantomData::default(), 121 | _out: PhantomData::default(), 122 | } 123 | } 124 | } 125 | 126 | #[cfg(test)] 127 | mod tests { 128 | use super::*; 129 | use crate::{utils::DummyStream, CommandRequest}; 130 | use anyhow::Result; 131 | use futures::prelude::*; 132 | 133 | #[tokio::test] 134 | async fn prost_stream_should_work() -> Result<()> { 135 | let buf = BytesMut::new(); 136 | let stream = 
DummyStream { buf }; 137 | 138 | // create the ProstStream 139 | let mut stream = ProstStream::<_, CommandRequest, CommandRequest>::new(stream); 140 | let cmd = CommandRequest::new_hdel("t1", "k1"); 141 | 142 | // send the command through the ProstStream 143 | stream.send(&cmd).await?; 144 | 145 | // read it back from the same ProstStream 146 | if let Some(Ok(s)) = stream.next().await { 147 | assert_eq!(s, cmd); 148 | } else { 149 | unreachable!(); 150 | } 151 | Ok(()) 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /src/network/stream_result.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | convert::{TryFrom, TryInto}, 3 | ops::{Deref, DerefMut}, 4 | pin::Pin, 5 | }; 6 | 7 | use futures::{Stream, StreamExt}; 8 | 9 | use crate::{CommandResponse, KvError}; 10 | 11 | /// Reads the subscription id when created, then derefs to the remaining response stream so it can be used like a Stream 12 | pub struct StreamResult { 13 | pub id: u32, 14 | inner: Pin> + Send>>, 15 | } 16 | 17 | impl StreamResult { 18 | pub async fn new(mut stream: T) -> Result 19 | where 20 | T: Stream> + Send + Unpin + 'static, 21 | { 22 | let id = match stream.next().await { 23 | Some(Ok(CommandResponse { 24 | status: 200, 25 | values: v, 26 | ..
27 | })) => { 28 | if v.is_empty() { 29 | return Err(KvError::Internal("Invalid stream".into())); 30 | } 31 | let id: i64 = (&v[0]).try_into().map_err(|_| KvError::Internal("Invalid stream".into()))?; 32 | u32::try_from(id).map_err(|_| KvError::Internal("Invalid stream".into())) 33 | } 34 | _ => Err(KvError::Internal("Invalid stream".into())), 35 | }; 36 | 37 | Ok(StreamResult { 38 | inner: Box::pin(stream), 39 | id: id?, 40 | }) 41 | } 42 | } 43 | 44 | impl Deref for StreamResult { 45 | type Target = Pin> + Send>>; 46 | 47 | fn deref(&self) -> &Self::Target { 48 | &self.inner 49 | } 50 | } 51 | 52 | impl DerefMut for StreamResult { 53 | fn deref_mut(&mut self) -> &mut Self::Target { 54 | &mut self.inner 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/network/tls.rs: -------------------------------------------------------------------------------- 1 | use std::io::Cursor; 2 | use std::sync::Arc; 3 | 4 | use tokio::io::{AsyncRead, AsyncWrite}; 5 | use tokio_rustls::rustls::{internal::pemfile, Certificate, ClientConfig, ServerConfig}; 6 | use tokio_rustls::rustls::{AllowAnyAuthenticatedClient, NoClientAuth, PrivateKey, RootCertStore}; 7 | use tokio_rustls::webpki::DNSNameRef; 8 | use tokio_rustls::TlsConnector; 9 | use tokio_rustls::{ 10 | client::TlsStream as ClientTlsStream, server::TlsStream as ServerTlsStream, TlsAcceptor, 11 | }; 12 | use tracing::instrument; 13 | 14 | use crate::KvError; 15 | 16 | /// The KV server's own ALPN (Application-Layer Protocol Negotiation) protocol id 17 | const ALPN_KV: &str = "kv"; 18 | 19 | /// Holds a TLS ServerConfig; `accept` upgrades an underlying stream to TLS 20 | #[derive(Clone)] 21 | pub struct TlsServerAcceptor { 22 | inner: Arc, 23 | } 24 | 25 | /// Holds a TLS ClientConfig plus the expected server domain; `connect` upgrades an underlying stream to TLS 26 | #[derive(Clone)] 27 | pub struct TlsClientConnector { 28 | pub config: Arc, 29 | pub domain: Arc, 30 | } 31 | 32 | impl TlsClientConnector { 33 | /// Build a ClientConfig from an optional client identity (cert, key) and an optional server CA 34 | #[instrument(name = "tls_connector_new", skip_all)] 35 | pub fn new( 36 | domain: impl Into + std::fmt::Debug, 37 | identity:
Option<(&str, &str)>, 38 | server_ca: Option<&str>, 39 | ) -> Result { 40 | let mut config = ClientConfig::new(); 41 | 42 | // load the client certificate and key when an identity is provided (client auth) 43 | if let Some((cert, key)) = identity { 44 | let certs = load_certs(cert)?; 45 | let key = load_key(key)?; 46 | config.set_single_client_cert(certs, key)?; 47 | } 48 | 49 | // when the CA that signed the server certificate is given, trust it so a server certificate outside 50 | // the system root chain can still be verified; a malformed CA is reported as an error, not a panic 51 | if let Some(cert) = server_ca { 52 | let mut buf = Cursor::new(cert); 53 | config.root_store.add_pem_file(&mut buf).map_err(|_| KvError::CertifcateParseError("CA", "cert"))?; 54 | } else { 55 | // otherwise trust the platform's native root certificates 56 | config.root_store = match rustls_native_certs::load_native_certs() { 57 | Ok(store) | Err((Some(store), _)) => store, 58 | Err((None, error)) => return Err(error.into()), 59 | }; 60 | } 61 | 62 | Ok(Self { 63 | config: Arc::new(config), 64 | domain: Arc::new(domain.into()), 65 | }) 66 | } 67 | 68 | #[instrument(name = "tls_client_connect", skip_all)] 69 | /// Run the TLS handshake, turning the underlying stream into a client TLS stream 70 | pub async fn connect(&self, stream: S) -> Result, KvError> 71 | where 72 | S: AsyncRead + AsyncWrite + Unpin + Send, 73 | { 74 | let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str()) 75 | .map_err(|_| KvError::Internal("Invalid DNS name".into()))?; 76 | 77 | let stream = TlsConnector::from(self.config.clone()) 78 | .connect(dns, stream) 79 | .await?; 80 | 81 | Ok(stream) 82 | } 83 | } 84 | 85 | impl TlsServerAcceptor { 86 | /// Build a ServerConfig from the server cert/key and an optional client CA 87 | #[instrument(name = "tls_acceptor_new", skip_all)] 88 | pub fn new(cert: &str, key: &str, client_ca: Option<&str>) -> Result { 89 | let certs = load_certs(cert)?; 90 | let key = load_key(key)?; 91 | 92 | let mut config = match client_ca { 93 | None => ServerConfig::new(NoClientAuth::new()), 94 | Some(cert) => { 95 | // client certificates are verified against this CA, so load it into the trust store 96 | let mut cert = Cursor::new(cert); 97 | let mut client_root_cert_store = RootCertStore::empty(); 98 | client_root_cert_store 99 | .add_pem_file(&mut cert) 100 | .map_err(|_|
KvError::CertifcateParseError("CA", "cert"))?; 101 | 102 | let client_auth = AllowAnyAuthenticatedClient::new(client_root_cert_store); 103 | ServerConfig::new(client_auth) 104 | } 105 | }; 106 | 107 | // 加载服务器证书 108 | config 109 | .set_single_cert(certs, key) 110 | .map_err(|_| KvError::CertifcateParseError("server", "cert"))?; 111 | config.set_protocols(&[Vec::from(ALPN_KV)]); 112 | 113 | Ok(Self { 114 | inner: Arc::new(config), 115 | }) 116 | } 117 | 118 | #[instrument(name = "tls_server_accept", skip_all)] 119 | /// 触发 TLS 协议,把底层的 stream 转换成 TLS stream 120 | pub async fn accept(&self, stream: S) -> Result, KvError> 121 | where 122 | S: AsyncRead + AsyncWrite + Unpin + Send, 123 | { 124 | let acceptor = TlsAcceptor::from(self.inner.clone()); 125 | Ok(acceptor.accept(stream).await?) 126 | } 127 | } 128 | 129 | fn load_certs(cert: &str) -> Result, KvError> { 130 | let mut cert = Cursor::new(cert); 131 | pemfile::certs(&mut cert).map_err(|_| KvError::CertifcateParseError("server", "cert")) 132 | } 133 | 134 | fn load_key(key: &str) -> Result { 135 | let mut cursor = Cursor::new(key); 136 | 137 | // 先尝试用 PKCS8 加载私钥 138 | if let Ok(mut keys) = pemfile::pkcs8_private_keys(&mut cursor) { 139 | if !keys.is_empty() { 140 | return Ok(keys.remove(0)); 141 | } 142 | } 143 | 144 | // 再尝试加载 RSA key 145 | cursor.set_position(0); 146 | if let Ok(mut keys) = pemfile::rsa_private_keys(&mut cursor) { 147 | if !keys.is_empty() { 148 | return Ok(keys.remove(0)); 149 | } 150 | } 151 | 152 | // 不支持的私钥类型 153 | Err(KvError::CertifcateParseError("private", "key")) 154 | } 155 | 156 | #[cfg(test)] 157 | pub mod tls_utils { 158 | use crate::{KvError, TlsClientConnector, TlsServerAcceptor}; 159 | 160 | const CA_CERT: &str = include_str!("../../fixtures/ca.cert"); 161 | const CLIENT_CERT: &str = include_str!("../../fixtures/client.cert"); 162 | const CLIENT_KEY: &str = include_str!("../../fixtures/client.key"); 163 | const SERVER_CERT: &str = include_str!("../../fixtures/server.cert"); 164 | 
const SERVER_KEY: &str = include_str!("../../fixtures/server.key"); 165 | 166 | pub fn tls_connector(client_cert: bool) -> Result { 167 | let ca = Some(CA_CERT); 168 | let client_identity = Some((CLIENT_CERT, CLIENT_KEY)); 169 | 170 | match client_cert { 171 | false => TlsClientConnector::new("kvserver.acme.inc", None, ca), 172 | true => TlsClientConnector::new("kvserver.acme.inc", client_identity, ca), 173 | } 174 | } 175 | 176 | pub fn tls_acceptor(client_cert: bool) -> Result { 177 | let ca = Some(CA_CERT); 178 | match client_cert { 179 | true => TlsServerAcceptor::new(SERVER_CERT, SERVER_KEY, ca), 180 | false => TlsServerAcceptor::new(SERVER_CERT, SERVER_KEY, None), 181 | } 182 | } 183 | } 184 | 185 | #[cfg(test)] 186 | mod tests { 187 | use super::tls_utils::tls_acceptor; 188 | use crate::network::tls::tls_utils::tls_connector; 189 | use anyhow::Result; 190 | use std::net::SocketAddr; 191 | use std::sync::Arc; 192 | use tokio::{ 193 | io::{AsyncReadExt, AsyncWriteExt}, 194 | net::{TcpListener, TcpStream}, 195 | }; 196 | 197 | #[tokio::test] 198 | async fn tls_should_work() -> Result<()> { 199 | let addr = start_server(false).await?; 200 | let connector = tls_connector(false)?; 201 | let stream = TcpStream::connect(addr).await?; 202 | let mut stream = connector.connect(stream).await?; 203 | stream.write_all(b"hello world!").await?; 204 | let mut buf = [0; 12]; 205 | stream.read_exact(&mut buf).await?; 206 | assert_eq!(&buf, b"hello world!"); 207 | 208 | Ok(()) 209 | } 210 | 211 | #[tokio::test] 212 | async fn tls_with_client_cert_should_work() -> Result<()> { 213 | let addr = start_server(true).await?; 214 | let connector = tls_connector(true)?; 215 | let stream = TcpStream::connect(addr).await?; 216 | let mut stream = connector.connect(stream).await?; 217 | stream.write_all(b"hello world!").await?; 218 | let mut buf = [0; 12]; 219 | stream.read_exact(&mut buf).await?; 220 | assert_eq!(&buf, b"hello world!"); 221 | 222 | Ok(()) 223 | } 224 | 225 | 
#[tokio::test] 226 | async fn tls_with_bad_domain_should_not_work() -> Result<()> { 227 | let addr = start_server(false).await?; 228 | 229 | let mut connector = tls_connector(false)?; 230 | connector.domain = Arc::new("kvserver1.acme.inc".into()); 231 | let stream = TcpStream::connect(addr).await?; 232 | let result = connector.connect(stream).await; 233 | 234 | assert!(result.is_err()); 235 | 236 | Ok(()) 237 | } 238 | 239 | async fn start_server(client_cert: bool) -> Result { 240 | let acceptor = tls_acceptor(client_cert)?; 241 | 242 | let echo = TcpListener::bind("127.0.0.1:0").await.unwrap(); 243 | let addr = echo.local_addr().unwrap(); 244 | 245 | tokio::spawn(async move { 246 | let (stream, _) = echo.accept().await.unwrap(); 247 | let mut stream = acceptor.accept(stream).await.unwrap(); 248 | let mut buf = [0; 12]; 249 | stream.read_exact(&mut buf).await.unwrap(); 250 | stream.write_all(&buf).await.unwrap(); 251 | }); 252 | 253 | Ok(addr) 254 | } 255 | } 256 | -------------------------------------------------------------------------------- /src/pb/abi.rs: -------------------------------------------------------------------------------- 1 | /// 来自客户端的命令请求 2 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 3 | pub struct CommandRequest { 4 | #[prost( 5 | oneof = "command_request::RequestData", 6 | tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12" 7 | )] 8 | pub request_data: ::core::option::Option, 9 | } 10 | /// Nested message and enum types in `CommandRequest`. 
11 | pub mod command_request { 12 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Oneof)] 13 | pub enum RequestData { 14 | #[prost(message, tag = "1")] 15 | Hget(super::Hget), 16 | #[prost(message, tag = "2")] 17 | Hgetall(super::Hgetall), 18 | #[prost(message, tag = "3")] 19 | Hmget(super::Hmget), 20 | #[prost(message, tag = "4")] 21 | Hset(super::Hset), 22 | #[prost(message, tag = "5")] 23 | Hmset(super::Hmset), 24 | #[prost(message, tag = "6")] 25 | Hdel(super::Hdel), 26 | #[prost(message, tag = "7")] 27 | Hmdel(super::Hmdel), 28 | #[prost(message, tag = "8")] 29 | Hexist(super::Hexist), 30 | #[prost(message, tag = "9")] 31 | Hmexist(super::Hmexist), 32 | #[prost(message, tag = "10")] 33 | Subscribe(super::Subscribe), 34 | #[prost(message, tag = "11")] 35 | Unsubscribe(super::Unsubscribe), 36 | #[prost(message, tag = "12")] 37 | Publish(super::Publish), 38 | } 39 | } 40 | /// 服务器的响应 41 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 42 | pub struct CommandResponse { 43 | /// 状态码;复用 HTTP 2xx/4xx/5xx 状态码 44 | #[prost(uint32, tag = "1")] 45 | pub status: u32, 46 | /// 如果不是 2xx,message 里包含详细的信息 47 | #[prost(string, tag = "2")] 48 | pub message: ::prost::alloc::string::String, 49 | /// 成功返回的 values 50 | #[prost(message, repeated, tag = "3")] 51 | pub values: ::prost::alloc::vec::Vec, 52 | /// 成功返回的 kv pairs 53 | #[prost(message, repeated, tag = "4")] 54 | pub pairs: ::prost::alloc::vec::Vec, 55 | } 56 | /// 从 table 中获取一个 key,返回 value 57 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 58 | pub struct Hget { 59 | #[prost(string, tag = "1")] 60 | pub table: ::prost::alloc::string::String, 61 | #[prost(string, tag = "2")] 62 | pub key: ::prost::alloc::string::String, 63 | } 64 | /// 从 table 中获取所有的 Kvpair 65 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 66 | pub struct Hgetall { 67 | #[prost(string, tag = "1")] 68 | pub table: ::prost::alloc::string::String, 69 | } 70 | /// 从 table 中获取一组 key,返回它们的 value 71 | #[derive(PartialOrd, 
Clone, PartialEq, ::prost::Message)] 72 | pub struct Hmget { 73 | #[prost(string, tag = "1")] 74 | pub table: ::prost::alloc::string::String, 75 | #[prost(string, repeated, tag = "2")] 76 | pub keys: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, 77 | } 78 | /// 返回的值 79 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 80 | pub struct Value { 81 | #[prost(oneof = "value::Value", tags = "1, 2, 3, 4, 5")] 82 | pub value: ::core::option::Option, 83 | } 84 | /// Nested message and enum types in `Value`. 85 | pub mod value { 86 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Oneof)] 87 | pub enum Value { 88 | #[prost(string, tag = "1")] 89 | String(::prost::alloc::string::String), 90 | #[prost(bytes, tag = "2")] 91 | Binary(::prost::bytes::Bytes), 92 | #[prost(int64, tag = "3")] 93 | Integer(i64), 94 | #[prost(double, tag = "4")] 95 | Float(f64), 96 | #[prost(bool, tag = "5")] 97 | Bool(bool), 98 | } 99 | } 100 | /// 返回的 kvpair 101 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 102 | pub struct Kvpair { 103 | #[prost(string, tag = "1")] 104 | pub key: ::prost::alloc::string::String, 105 | #[prost(message, optional, tag = "2")] 106 | pub value: ::core::option::Option, 107 | } 108 | /// 往 table 里存一个 kvpair, 109 | /// 如果 table 不存在就创建这个 table 110 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 111 | pub struct Hset { 112 | #[prost(string, tag = "1")] 113 | pub table: ::prost::alloc::string::String, 114 | #[prost(message, optional, tag = "2")] 115 | pub pair: ::core::option::Option, 116 | } 117 | /// 往 table 中存一组 kvpair, 118 | /// 如果 table 不存在就创建这个 table 119 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 120 | pub struct Hmset { 121 | #[prost(string, tag = "1")] 122 | pub table: ::prost::alloc::string::String, 123 | #[prost(message, repeated, tag = "2")] 124 | pub pairs: ::prost::alloc::vec::Vec, 125 | } 126 | /// 从 table 中删除一个 key,返回它之前的值 127 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 128 | pub 
struct Hdel { 129 | #[prost(string, tag = "1")] 130 | pub table: ::prost::alloc::string::String, 131 | #[prost(string, tag = "2")] 132 | pub key: ::prost::alloc::string::String, 133 | } 134 | /// 从 table 中删除一组 key,返回它们之前的值 135 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 136 | pub struct Hmdel { 137 | #[prost(string, tag = "1")] 138 | pub table: ::prost::alloc::string::String, 139 | #[prost(string, repeated, tag = "2")] 140 | pub keys: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, 141 | } 142 | /// 查看 key 是否存在 143 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 144 | pub struct Hexist { 145 | #[prost(string, tag = "1")] 146 | pub table: ::prost::alloc::string::String, 147 | #[prost(string, tag = "2")] 148 | pub key: ::prost::alloc::string::String, 149 | } 150 | /// 查看一组 key 是否存在 151 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 152 | pub struct Hmexist { 153 | #[prost(string, tag = "1")] 154 | pub table: ::prost::alloc::string::String, 155 | #[prost(string, repeated, tag = "2")] 156 | pub keys: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, 157 | } 158 | /// subscribe 到某个主题,任何发布到这个主题的数据都会被收到 159 | /// 成功后,第一个返回的 CommandResponse,我们返回一个唯一的 subscription id 160 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 161 | pub struct Subscribe { 162 | #[prost(string, tag = "1")] 163 | pub topic: ::prost::alloc::string::String, 164 | } 165 | /// 取消对某个主题的订阅 166 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 167 | pub struct Unsubscribe { 168 | #[prost(string, tag = "1")] 169 | pub topic: ::prost::alloc::string::String, 170 | #[prost(uint32, tag = "2")] 171 | pub id: u32, 172 | } 173 | /// 发布数据到某个主题 174 | #[derive(PartialOrd, Clone, PartialEq, ::prost::Message)] 175 | pub struct Publish { 176 | #[prost(string, tag = "1")] 177 | pub topic: ::prost::alloc::string::String, 178 | #[prost(message, repeated, tag = "2")] 179 | pub data: ::prost::alloc::vec::Vec, 180 | } 181 | 
-------------------------------------------------------------------------------- /src/pb/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod abi; 2 | 3 | use std::convert::{TryFrom, TryInto}; 4 | 5 | use abi::{command_request::RequestData, *}; 6 | use bytes::Bytes; 7 | use http::StatusCode; 8 | use prost::Message; 9 | 10 | use crate::KvError; 11 | 12 | impl CommandRequest { 13 | pub fn new_hget(table: impl Into, key: impl Into) -> Self { 14 | Self { 15 | request_data: Some(RequestData::Hget(Hget { 16 | table: table.into(), 17 | key: key.into(), 18 | })), 19 | } 20 | } 21 | 22 | pub fn new_hgetall(table: impl Into) -> Self { 23 | Self { 24 | request_data: Some(RequestData::Hgetall(Hgetall { 25 | table: table.into(), 26 | })), 27 | } 28 | } 29 | 30 | pub fn new_hmget(table: impl Into, keys: Vec) -> Self { 31 | Self { 32 | request_data: Some(RequestData::Hmget(Hmget { 33 | table: table.into(), 34 | keys, 35 | })), 36 | } 37 | } 38 | 39 | pub fn new_hset(table: impl Into, key: impl Into, value: Value) -> Self { 40 | Self { 41 | request_data: Some(RequestData::Hset(Hset { 42 | table: table.into(), 43 | pair: Some(Kvpair::new(key, value)), 44 | })), 45 | } 46 | } 47 | 48 | pub fn new_hmset(table: impl Into, pairs: Vec) -> Self { 49 | Self { 50 | request_data: Some(RequestData::Hmset(Hmset { 51 | table: table.into(), 52 | pairs, 53 | })), 54 | } 55 | } 56 | 57 | pub fn new_hdel(table: impl Into, key: impl Into) -> Self { 58 | Self { 59 | request_data: Some(RequestData::Hdel(Hdel { 60 | table: table.into(), 61 | key: key.into(), 62 | })), 63 | } 64 | } 65 | 66 | pub fn new_hmdel(table: impl Into, keys: Vec) -> Self { 67 | Self { 68 | request_data: Some(RequestData::Hmdel(Hmdel { 69 | table: table.into(), 70 | keys, 71 | })), 72 | } 73 | } 74 | 75 | pub fn new_hexist(table: impl Into, key: impl Into) -> Self { 76 | Self { 77 | request_data: Some(RequestData::Hexist(Hexist { 78 | table: table.into(), 79 | key: key.into(), 80 | 
})), 81 | } 82 | } 83 | 84 | pub fn new_hmexist(table: impl Into, keys: Vec) -> Self { 85 | Self { 86 | request_data: Some(RequestData::Hmexist(Hmexist { 87 | table: table.into(), 88 | keys, 89 | })), 90 | } 91 | } 92 | 93 | pub fn new_subscribe(name: impl Into) -> Self { 94 | Self { 95 | request_data: Some(RequestData::Subscribe(Subscribe { topic: name.into() })), 96 | } 97 | } 98 | 99 | pub fn new_unsubscribe(name: impl Into, id: u32) -> Self { 100 | Self { 101 | request_data: Some(RequestData::Unsubscribe(Unsubscribe { 102 | topic: name.into(), 103 | id, 104 | })), 105 | } 106 | } 107 | 108 | pub fn new_publish(name: impl Into, data: Vec) -> Self { 109 | Self { 110 | request_data: Some(RequestData::Publish(Publish { 111 | topic: name.into(), 112 | data, 113 | })), 114 | } 115 | } 116 | 117 | /// 转换成 string 做错误处理 118 | pub fn format(&self) -> String { 119 | format!("{:?}", self) 120 | } 121 | } 122 | 123 | impl CommandResponse { 124 | pub fn ok() -> Self { 125 | CommandResponse { 126 | status: StatusCode::OK.as_u16() as _, 127 | ..Default::default() 128 | } 129 | } 130 | 131 | pub fn internal_error(msg: String) -> Self { 132 | CommandResponse { 133 | status: StatusCode::INTERNAL_SERVER_ERROR.as_u16() as _, 134 | message: msg, 135 | ..Default::default() 136 | } 137 | } 138 | 139 | /// 转换成 string 做错误处理 140 | pub fn format(&self) -> String { 141 | format!("{:?}", self) 142 | } 143 | } 144 | 145 | impl Value { 146 | /// 转换成 string 做错误处理 147 | pub fn format(&self) -> String { 148 | format!("{:?}", self) 149 | } 150 | } 151 | 152 | impl Kvpair { 153 | /// 创建一个新的 kv pair 154 | pub fn new(key: impl Into, value: Value) -> Self { 155 | Self { 156 | key: key.into(), 157 | value: Some(value), 158 | } 159 | } 160 | } 161 | 162 | /// 从 String 转换成 Value 163 | impl From for Value { 164 | fn from(s: String) -> Self { 165 | Self { 166 | value: Some(value::Value::String(s)), 167 | } 168 | } 169 | } 170 | 171 | /// 从 &str 转换成 Value 172 | impl From<&str> for Value { 173 | fn from(s: 
&str) -> Self { 174 | Self { 175 | value: Some(value::Value::String(s.into())), 176 | } 177 | } 178 | } 179 | 180 | /// 从 i64转换成 Value 181 | impl From for Value { 182 | fn from(i: i64) -> Self { 183 | Self { 184 | value: Some(value::Value::Integer(i)), 185 | } 186 | } 187 | } 188 | 189 | impl From<&[u8; N]> for Value { 190 | fn from(buf: &[u8; N]) -> Self { 191 | Bytes::copy_from_slice(&buf[..]).into() 192 | } 193 | } 194 | 195 | impl From for Value { 196 | fn from(buf: Bytes) -> Self { 197 | Self { 198 | value: Some(value::Value::Binary(buf)), 199 | } 200 | } 201 | } 202 | 203 | /// 从 Value 转换成 CommandResponse 204 | impl From for CommandResponse { 205 | fn from(v: Value) -> Self { 206 | Self { 207 | status: StatusCode::OK.as_u16() as _, 208 | values: vec![v], 209 | ..Default::default() 210 | } 211 | } 212 | } 213 | 214 | /// 从 Vec 转换成 CommandResponse 215 | impl From> for CommandResponse { 216 | fn from(v: Vec) -> Self { 217 | Self { 218 | status: StatusCode::OK.as_u16() as _, 219 | pairs: v, 220 | ..Default::default() 221 | } 222 | } 223 | } 224 | 225 | /// 从 KvError 转换成 CommandResponse 226 | impl From for CommandResponse { 227 | fn from(e: KvError) -> Self { 228 | let mut result = Self { 229 | status: StatusCode::INTERNAL_SERVER_ERROR.as_u16() as _, 230 | message: e.to_string(), 231 | values: vec![], 232 | pairs: vec![], 233 | }; 234 | 235 | match e { 236 | KvError::NotFound(_) => result.status = StatusCode::NOT_FOUND.as_u16() as _, 237 | KvError::InvalidCommand(_) => result.status = StatusCode::BAD_REQUEST.as_u16() as _, 238 | _ => {} 239 | } 240 | 241 | result 242 | } 243 | } 244 | 245 | impl From> for CommandResponse { 246 | fn from(v: Vec) -> Self { 247 | Self { 248 | status: StatusCode::OK.as_u16() as _, 249 | values: v, 250 | ..Default::default() 251 | } 252 | } 253 | } 254 | 255 | impl From for Value { 256 | fn from(b: bool) -> Self { 257 | Self { 258 | value: Some(value::Value::Bool(b)), 259 | } 260 | } 261 | } 262 | 263 | impl From for Value { 264 | fn 
from(f: f64) -> Self { 265 | Self { 266 | value: Some(value::Value::Float(f)), 267 | } 268 | } 269 | } 270 | 271 | impl TryFrom for i64 { 272 | type Error = KvError; 273 | 274 | fn try_from(v: Value) -> Result { 275 | match v.value { 276 | Some(value::Value::Integer(i)) => Ok(i), 277 | _ => Err(KvError::ConvertError(v.format(), "Integer")), 278 | } 279 | } 280 | } 281 | 282 | impl TryFrom<&Value> for i64 { 283 | type Error = KvError; 284 | 285 | fn try_from(v: &Value) -> Result { 286 | match v.value { 287 | Some(value::Value::Integer(i)) => Ok(i), 288 | _ => Err(KvError::ConvertError(v.format(), "Integer")), 289 | } 290 | } 291 | } 292 | 293 | impl TryFrom for f64 { 294 | type Error = KvError; 295 | 296 | fn try_from(v: Value) -> Result { 297 | match v.value { 298 | Some(value::Value::Float(f)) => Ok(f), 299 | _ => Err(KvError::ConvertError(v.format(), "Float")), 300 | } 301 | } 302 | } 303 | 304 | impl TryFrom for Bytes { 305 | type Error = KvError; 306 | 307 | fn try_from(v: Value) -> Result { 308 | match v.value { 309 | Some(value::Value::Binary(b)) => Ok(b), 310 | _ => Err(KvError::ConvertError(v.format(), "Binary")), 311 | } 312 | } 313 | } 314 | 315 | impl TryFrom for bool { 316 | type Error = KvError; 317 | 318 | fn try_from(v: Value) -> Result { 319 | match v.value { 320 | Some(value::Value::Bool(b)) => Ok(b), 321 | _ => Err(KvError::ConvertError(v.format(), "Boolean")), 322 | } 323 | } 324 | } 325 | 326 | impl TryFrom for Vec { 327 | type Error = KvError; 328 | fn try_from(v: Value) -> Result { 329 | let mut buf = Vec::with_capacity(v.encoded_len()); 330 | v.encode(&mut buf)?; 331 | Ok(buf) 332 | } 333 | } 334 | 335 | impl TryFrom<&[u8]> for Value { 336 | type Error = KvError; 337 | 338 | fn try_from(data: &[u8]) -> Result { 339 | let msg = Value::decode(data)?; 340 | Ok(msg) 341 | } 342 | } 343 | 344 | impl TryFrom<&CommandResponse> for i64 { 345 | type Error = KvError; 346 | 347 | fn try_from(value: &CommandResponse) -> Result { 348 | if value.status != 
StatusCode::OK.as_u16() as u32 { 349 | return Err(KvError::ConvertError(value.format(), "CommandResponse")); 350 | } 351 | match value.values.get(0) { 352 | Some(v) => v.try_into(), 353 | None => Err(KvError::ConvertError(value.format(), "CommandResponse")), 354 | } 355 | } 356 | } 357 | -------------------------------------------------------------------------------- /src/server.rs: -------------------------------------------------------------------------------- 1 | use std::{env, str::FromStr}; 2 | 3 | use anyhow::Result; 4 | use simple_kv::{start_server_with_config, RotationConfig, ServerConfig}; 5 | use tokio::fs; 6 | use tracing::span; 7 | use tracing_subscriber::{ 8 | filter, 9 | fmt::{self, format}, 10 | layer::SubscriberExt, 11 | prelude::*, 12 | EnvFilter, 13 | }; 14 | 15 | #[tokio::main] 16 | async fn main() -> Result<()> { 17 | let config = match env::var("KV_SERVER_CONFIG") { 18 | Ok(path) => fs::read_to_string(&path).await?, 19 | Err(_) => include_str!("../fixtures/quic_server.conf").to_string(), 20 | }; 21 | let config: ServerConfig = toml::from_str(&config)?; 22 | let log = &config.log; 23 | 24 | env::set_var("RUST_LOG", &log.log_level); 25 | 26 | let stdout_log = fmt::layer().compact(); 27 | 28 | let tracer = opentelemetry_jaeger::new_pipeline() 29 | .with_service_name("kv-server") 30 | .install_simple()?; 31 | let opentelemetry = tracing_opentelemetry::layer().with_tracer(tracer); 32 | 33 | let file_appender = match log.rotation { 34 | RotationConfig::Hourly => tracing_appender::rolling::hourly(&log.path, "server.log"), 35 | RotationConfig::Daily => tracing_appender::rolling::daily(&log.path, "server.log"), 36 | RotationConfig::Never => tracing_appender::rolling::never(&log.path, "server.log"), 37 | }; 38 | 39 | let (non_blocking, _guard1) = tracing_appender::non_blocking(file_appender); 40 | let fmt_layer = fmt::layer() 41 | .event_format(format().compact()) 42 | .with_writer(non_blocking); 43 | 44 | let level = 
filter::LevelFilter::from_str(&log.log_level)?; 45 | let jaeger_level = match log.enable_log_file { 46 | true => level, 47 | false => filter::LevelFilter::OFF, 48 | }; 49 | 50 | let log_file_level = match log.enable_log_file { 51 | true => level, 52 | false => filter::LevelFilter::OFF, 53 | }; 54 | 55 | tracing_subscriber::registry() 56 | .with(EnvFilter::from_default_env()) 57 | .with(stdout_log) 58 | .with(fmt_layer.with_filter(log_file_level)) 59 | .with(opentelemetry.with_filter(jaeger_level)) 60 | .init(); 61 | 62 | let root = span!(tracing::Level::INFO, "app_start"); 63 | let _enter = root.enter(); 64 | 65 | start_server_with_config(&config).await?; 66 | 67 | Ok(()) 68 | } 69 | -------------------------------------------------------------------------------- /src/service/command_service.rs: -------------------------------------------------------------------------------- 1 | use crate::*; 2 | 3 | impl CommandService for Hget { 4 | fn execute(self, store: &impl Storage) -> CommandResponse { 5 | match store.get(&self.table, &self.key) { 6 | Ok(Some(v)) => v.into(), 7 | Ok(None) => KvError::NotFound(format!("table {}, key {}", self.table, self.key)).into(), 8 | Err(e) => e.into(), 9 | } 10 | } 11 | } 12 | 13 | impl CommandService for Hmget { 14 | fn execute(self, store: &impl Storage) -> CommandResponse { 15 | self.keys 16 | .iter() 17 | .map(|key| match store.get(&self.table, key) { 18 | Ok(Some(v)) => v, 19 | _ => Value::default(), 20 | }) 21 | .collect::>() 22 | .into() 23 | } 24 | } 25 | 26 | impl CommandService for Hgetall { 27 | fn execute(self, store: &impl Storage) -> CommandResponse { 28 | match store.get_all(&self.table) { 29 | Ok(v) => v.into(), 30 | Err(e) => e.into(), 31 | } 32 | } 33 | } 34 | 35 | impl CommandService for Hset { 36 | fn execute(self, store: &impl Storage) -> CommandResponse { 37 | match self.pair { 38 | Some(v) => match store.set(&self.table, v.key, v.value.unwrap_or_default()) { 39 | Ok(Some(v)) => v.into(), 40 | Ok(None) => 
Value::default().into(), 41 | Err(e) => e.into(), 42 | }, 43 | None => KvError::InvalidCommand(format!("{:?}", self)).into(), 44 | } 45 | } 46 | } 47 | 48 | impl CommandService for Hmset { 49 | fn execute(self, store: &impl Storage) -> CommandResponse { 50 | let pairs = self.pairs; 51 | let table = self.table; 52 | pairs 53 | .into_iter() 54 | .map(|pair| { 55 | let result = store.set(&table, pair.key, pair.value.unwrap_or_default()); 56 | match result { 57 | Ok(Some(v)) => v, 58 | _ => Value::default(), 59 | } 60 | }) 61 | .collect::>() 62 | .into() 63 | } 64 | } 65 | 66 | impl CommandService for Hdel { 67 | fn execute(self, store: &impl Storage) -> CommandResponse { 68 | match store.del(&self.table, &self.key) { 69 | Ok(Some(v)) => v.into(), 70 | Ok(None) => Value::default().into(), 71 | Err(e) => e.into(), 72 | } 73 | } 74 | } 75 | 76 | impl CommandService for Hmdel { 77 | fn execute(self, store: &impl Storage) -> CommandResponse { 78 | self.keys 79 | .iter() 80 | .map(|key| match store.del(&self.table, key) { 81 | Ok(Some(v)) => v, 82 | _ => Value::default(), 83 | }) 84 | .collect::>() 85 | .into() 86 | } 87 | } 88 | 89 | impl CommandService for Hexist { 90 | fn execute(self, store: &impl Storage) -> CommandResponse { 91 | match store.contains(&self.table, &self.key) { 92 | Ok(v) => Value::from(v).into(), 93 | Err(e) => e.into(), 94 | } 95 | } 96 | } 97 | 98 | impl CommandService for Hmexist { 99 | fn execute(self, store: &impl Storage) -> CommandResponse { 100 | self.keys 101 | .iter() 102 | .map(|key| match store.contains(&self.table, key) { 103 | Ok(v) => v.into(), 104 | _ => Value::default(), 105 | }) 106 | .collect::>() 107 | .into() 108 | } 109 | } 110 | 111 | #[cfg(test)] 112 | mod tests { 113 | use super::*; 114 | 115 | #[test] 116 | fn hget_should_work() { 117 | let store = MemTable::new(); 118 | let cmd = CommandRequest::new_hset("score", "u1", 10.into()); 119 | dispatch(cmd, &store); 120 | let cmd = CommandRequest::new_hget("score", "u1"); 121 | let 
res = dispatch(cmd, &store); 122 | assert_res_ok(&res, &[10.into()], &[]); 123 | } 124 | 125 | #[test] 126 | fn hget_with_non_exist_key_should_return_404() { 127 | let store = MemTable::new(); 128 | let cmd = CommandRequest::new_hget("score", "u1"); 129 | let res = dispatch(cmd, &store); 130 | assert_res_error(&res, 404, "Not found"); 131 | } 132 | 133 | #[test] 134 | fn hmget_should_work() { 135 | let store = MemTable::new(); 136 | 137 | set_key_pairs( 138 | "user", 139 | vec![("u1", "Tyr"), ("u2", "Lindsey"), ("u3", "Rosie")], 140 | &store, 141 | ); 142 | 143 | let cmd = CommandRequest::new_hmget("user", vec!["u1".into(), "u4".into(), "u3".into()]); 144 | let res = dispatch(cmd, &store); 145 | let values = &["Tyr".into(), Value::default(), "Rosie".into()]; 146 | assert_res_ok(&res, values, &[]); 147 | } 148 | 149 | #[test] 150 | fn hgetall_should_work() { 151 | let store = MemTable::new(); 152 | 153 | set_key_pairs( 154 | "score", 155 | vec![("u1", 10), ("u2", 8), ("u3", 11), ("u1", 6)], 156 | &store, 157 | ); 158 | 159 | let cmd = CommandRequest::new_hgetall("score"); 160 | let res = dispatch(cmd, &store); 161 | let pairs = &[ 162 | Kvpair::new("u1", 6.into()), 163 | Kvpair::new("u2", 8.into()), 164 | Kvpair::new("u3", 11.into()), 165 | ]; 166 | assert_res_ok(&res, &[], pairs); 167 | } 168 | 169 | #[test] 170 | fn hset_should_work() { 171 | let store = MemTable::new(); 172 | let cmd = CommandRequest::new_hset("t1", "hello", "world".into()); 173 | let res = dispatch(cmd.clone(), &store); 174 | assert_res_ok(&res, &[Value::default()], &[]); 175 | 176 | let res = dispatch(cmd, &store); 177 | assert_res_ok(&res, &["world".into()], &[]); 178 | } 179 | 180 | #[test] 181 | fn hmset_should_work() { 182 | let store = MemTable::new(); 183 | set_key_pairs("t1", vec![("u1", "world")], &store); 184 | let pairs = vec![ 185 | Kvpair::new("u1", 10.1.into()), 186 | Kvpair::new("u2", 8.1.into()), 187 | ]; 188 | let cmd = CommandRequest::new_hmset("t1", pairs); 189 | let res = 
dispatch(cmd, &store); 190 | assert_res_ok(&res, &["world".into(), Value::default()], &[]); 191 | } 192 | 193 | #[test] 194 | fn hdel_should_work() { 195 | let store = MemTable::new(); 196 | set_key_pairs("t1", vec![("u1", "v1")], &store); 197 | let cmd = CommandRequest::new_hdel("t1", "u2"); 198 | let res = dispatch(cmd, &store); 199 | assert_res_ok(&res, &[Value::default()], &[]); 200 | 201 | let cmd = CommandRequest::new_hdel("t1", "u1"); 202 | let res = dispatch(cmd, &store); 203 | assert_res_ok(&res, &["v1".into()], &[]); 204 | } 205 | 206 | #[test] 207 | fn hmdel_should_work() { 208 | let store = MemTable::new(); 209 | set_key_pairs("t1", vec![("u1", "v1"), ("u2", "v2")], &store); 210 | 211 | let cmd = CommandRequest::new_hmdel("t1", vec!["u1".into(), "u3".into()]); 212 | let res = dispatch(cmd, &store); 213 | assert_res_ok(&res, &["v1".into(), Value::default()], &[]); 214 | } 215 | 216 | #[test] 217 | fn hexist_should_work() { 218 | let store = MemTable::new(); 219 | set_key_pairs("t1", vec![("u1", "v1")], &store); 220 | let cmd = CommandRequest::new_hexist("t1", "u2"); 221 | let res = dispatch(cmd, &store); 222 | assert_res_ok(&res, &[false.into()], &[]); 223 | 224 | let cmd = CommandRequest::new_hexist("t1", "u1"); 225 | let res = dispatch(cmd, &store); 226 | assert_res_ok(&res, &[true.into()], &[]); 227 | } 228 | 229 | #[test] 230 | fn hmexist_should_work() { 231 | let store = MemTable::new(); 232 | set_key_pairs("t1", vec![("u1", "v1"), ("u2", "v2")], &store); 233 | 234 | let cmd = CommandRequest::new_hmexist("t1", vec!["u1".into(), "u3".into()]); 235 | let res = dispatch(cmd, &store); 236 | assert_res_ok(&res, &[true.into(), false.into()], &[]); 237 | } 238 | 239 | fn set_key_pairs>(table: &str, pairs: Vec<(&str, T)>, store: &impl Storage) { 240 | pairs 241 | .into_iter() 242 | .map(|(k, v)| CommandRequest::new_hset(table, k, v.into())) 243 | .for_each(|cmd| { 244 | dispatch(cmd, store); 245 | }); 246 | } 247 | } 248 | 
-------------------------------------------------------------------------------- /src/service/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | command_request::RequestData, CommandRequest, CommandResponse, KvError, MemTable, Storage, 3 | }; 4 | use futures::stream; 5 | use std::sync::Arc; 6 | use tracing::{debug, instrument}; 7 | 8 | mod command_service; 9 | mod topic; 10 | mod topic_service; 11 | 12 | pub use topic::{Broadcaster, Topic}; 13 | pub use topic_service::{StreamingResponse, TopicService}; 14 | 15 | /// 对 Command 的处理的抽象 16 | pub trait CommandService { 17 | /// 处理 Command,返回 Response 18 | fn execute(self, store: &impl Storage) -> CommandResponse; 19 | } 20 | 21 | /// 事件通知(不可变事件) 22 | pub trait Notify { 23 | fn notify(&self, arg: &Arg); 24 | } 25 | 26 | /// 事件通知(可变事件) 27 | pub trait NotifyMut { 28 | fn notify(&self, arg: &mut Arg); 29 | } 30 | 31 | impl Notify for Vec { 32 | #[inline] 33 | fn notify(&self, arg: &Arg) { 34 | for f in self { 35 | f(arg) 36 | } 37 | } 38 | } 39 | 40 | impl NotifyMut for Vec { 41 | #[inline] 42 | fn notify(&self, arg: &mut Arg) { 43 | for f in self { 44 | f(arg) 45 | } 46 | } 47 | } 48 | 49 | /// Service 数据结构 50 | pub struct Service { 51 | inner: Arc>, 52 | broadcaster: Arc, 53 | } 54 | 55 | impl Clone for Service { 56 | fn clone(&self) -> Self { 57 | Self { 58 | inner: Arc::clone(&self.inner), 59 | broadcaster: Arc::clone(&self.broadcaster), 60 | } 61 | } 62 | } 63 | 64 | /// Service 内部数据结构 65 | pub struct ServiceInner { 66 | store: Store, 67 | on_received: Vec, 68 | on_executed: Vec, 69 | on_before_send: Vec, 70 | on_after_send: Vec, 71 | } 72 | 73 | impl ServiceInner { 74 | pub fn new(store: Store) -> Self { 75 | Self { 76 | store, 77 | on_received: Vec::new(), 78 | on_executed: Vec::new(), 79 | on_before_send: Vec::new(), 80 | on_after_send: Vec::new(), 81 | } 82 | } 83 | 84 | pub fn fn_received(mut self, f: fn(&CommandRequest)) -> Self { 85 | 
self.on_received.push(f); 86 | self 87 | } 88 | 89 | pub fn fn_executed(mut self, f: fn(&CommandResponse)) -> Self { 90 | self.on_executed.push(f); 91 | self 92 | } 93 | 94 | pub fn fn_before_send(mut self, f: fn(&mut CommandResponse)) -> Self { 95 | self.on_before_send.push(f); 96 | self 97 | } 98 | 99 | pub fn fn_after_send(mut self, f: fn()) -> Self { 100 | self.on_after_send.push(f); 101 | self 102 | } 103 | } 104 | 105 | impl From> for Service { 106 | fn from(inner: ServiceInner) -> Self { 107 | Self { 108 | inner: Arc::new(inner), 109 | broadcaster: Default::default(), 110 | } 111 | } 112 | } 113 | 114 | impl Service { 115 | #[instrument(name = "service_execute", skip_all)] 116 | pub fn execute(&self, cmd: CommandRequest) -> StreamingResponse { 117 | debug!("Got request: {:?}", cmd); 118 | self.inner.on_received.notify(&cmd); 119 | let mut res = dispatch(cmd.clone(), &self.inner.store); 120 | 121 | if res == CommandResponse::default() { 122 | dispatch_stream(cmd, Arc::clone(&self.broadcaster)) 123 | } else { 124 | debug!("Executed response: {:?}", res); 125 | self.inner.on_executed.notify(&res); 126 | self.inner.on_before_send.notify(&mut res); 127 | if !self.inner.on_before_send.is_empty() { 128 | debug!("Modified response: {:?}", res); 129 | } 130 | 131 | Box::pin(stream::once(async { Arc::new(res) })) 132 | } 133 | } 134 | } 135 | 136 | /// 从 Request 中得到 Response,目前处理所有 HGET/HSET/HDEL/HEXIST 137 | pub fn dispatch(cmd: CommandRequest, store: &impl Storage) -> CommandResponse { 138 | match cmd.request_data { 139 | Some(RequestData::Hget(param)) => param.execute(store), 140 | Some(RequestData::Hgetall(param)) => param.execute(store), 141 | Some(RequestData::Hmget(param)) => param.execute(store), 142 | Some(RequestData::Hset(param)) => param.execute(store), 143 | Some(RequestData::Hmset(param)) => param.execute(store), 144 | Some(RequestData::Hdel(param)) => param.execute(store), 145 | Some(RequestData::Hmdel(param)) => param.execute(store), 146 | 
Some(RequestData::Hexist(param)) => param.execute(store), 147 | Some(RequestData::Hmexist(param)) => param.execute(store), 148 | None => KvError::InvalidCommand("Request has no data".into()).into(), 149 | // 处理不了的返回一个啥都不包括的 Response,这样后续可以用 dispatch_stream 处理 150 | _ => CommandResponse::default(), 151 | } 152 | } 153 | 154 | /// 从 Request 中得到 Response,目前处理所有 PUBLISH/SUBSCRIBE/UNSUBSCRIBE 155 | pub fn dispatch_stream(cmd: CommandRequest, topic: impl Topic) -> StreamingResponse { 156 | match cmd.request_data { 157 | Some(RequestData::Publish(param)) => param.execute(topic), 158 | Some(RequestData::Subscribe(param)) => param.execute(topic), 159 | Some(RequestData::Unsubscribe(param)) => param.execute(topic), 160 | // 如果走到这里,就是代码逻辑的问题,直接 crash 出来 161 | _ => unreachable!(), 162 | } 163 | } 164 | 165 | #[cfg(test)] 166 | mod tests { 167 | use http::StatusCode; 168 | use tokio_stream::StreamExt; 169 | use tracing::info; 170 | 171 | use super::*; 172 | use crate::{MemTable, Value}; 173 | 174 | #[tokio::test] 175 | async fn service_should_works() { 176 | // 我们需要一个 service 结构至少包含 Storage 177 | let service: Service = ServiceInner::new(MemTable::default()).into(); 178 | 179 | // service 可以运行在多线程环境下,它的 clone 应该是轻量级的 180 | let cloned = service.clone(); 181 | 182 | // 创建一个线程,在 table t1 中写入 k1, v1 183 | tokio::spawn(async move { 184 | let mut res = cloned.execute(CommandRequest::new_hset("t1", "k1", "v1".into())); 185 | let data = res.next().await.unwrap(); 186 | assert_res_ok(&data, &[Value::default()], &[]); 187 | }) 188 | .await 189 | .unwrap(); 190 | 191 | // 在当前线程下读取 table t1 的 k1,应该返回 v1 192 | let mut res = service.execute(CommandRequest::new_hget("t1", "k1")); 193 | let data = res.next().await.unwrap(); 194 | assert_res_ok(&data, &["v1".into()], &[]); 195 | } 196 | 197 | #[tokio::test] 198 | async fn event_registration_should_work() { 199 | fn b(cmd: &CommandRequest) { 200 | info!("Got {:?}", cmd); 201 | } 202 | fn c(res: &CommandResponse) { 203 | info!("{:?}", res); 204 | 
} 205 | fn d(res: &mut CommandResponse) { 206 | res.status = StatusCode::CREATED.as_u16() as _; 207 | } 208 | fn e() { 209 | info!("Data is sent"); 210 | } 211 | 212 | let service: Service = ServiceInner::new(MemTable::default()) 213 | .fn_received(|_: &CommandRequest| {}) 214 | .fn_received(b) 215 | .fn_executed(c) 216 | .fn_before_send(d) 217 | .fn_after_send(e) 218 | .into(); 219 | 220 | let mut res = service.execute(CommandRequest::new_hset("t1", "k1", "v1".into())); 221 | let data = res.next().await.unwrap(); 222 | assert_eq!(data.status, StatusCode::CREATED.as_u16() as u32); 223 | assert_eq!(data.message, ""); 224 | assert_eq!(data.values, vec![Value::default()]); 225 | } 226 | } 227 | 228 | #[cfg(test)] 229 | use crate::{Kvpair, Value}; 230 | 231 | // 测试成功返回的结果 232 | #[cfg(test)] 233 | pub fn assert_res_ok(res: &CommandResponse, values: &[Value], pairs: &[Kvpair]) { 234 | let mut sorted_pairs = res.pairs.clone(); 235 | sorted_pairs.sort_by(|a, b| a.partial_cmp(b).unwrap()); 236 | assert_eq!(res.status, 200); 237 | assert_eq!(res.message, ""); 238 | assert_eq!(res.values, values); 239 | assert_eq!(sorted_pairs, pairs); 240 | } 241 | 242 | // 测试失败返回的结果 243 | #[cfg(test)] 244 | pub fn assert_res_error(res: &CommandResponse, code: u32, msg: &str) { 245 | assert_eq!(res.status, code); 246 | assert!(res.message.contains(msg)); 247 | assert_eq!(res.values, &[]); 248 | assert_eq!(res.pairs, &[]); 249 | } 250 | -------------------------------------------------------------------------------- /src/service/topic.rs: -------------------------------------------------------------------------------- 1 | use dashmap::{DashMap, DashSet}; 2 | use std::sync::{ 3 | atomic::{AtomicU32, Ordering}, 4 | Arc, 5 | }; 6 | use tokio::sync::mpsc; 7 | use tracing::{debug, info, instrument, warn}; 8 | 9 | use crate::{CommandResponse, KvError, Value}; 10 | 11 | /// topic 里最大存放的数据 12 | const BROADCAST_CAPACITY: usize = 128; 13 | 14 | /// 下一个 subscription id 15 | static NEXT_ID: AtomicU32 = 
AtomicU32::new(1); 16 | 17 | /// 获取下一个 subscription id 18 | fn get_next_subscription_id() -> u32 { 19 | NEXT_ID.fetch_add(1, Ordering::Relaxed) 20 | } 21 | 22 | pub trait Topic: Send + Sync + 'static { 23 | /// 订阅某个主题 24 | fn subscribe(self, name: String) -> mpsc::Receiver>; 25 | /// 取消对主题的订阅 26 | fn unsubscribe(self, name: String, id: u32) -> Result; 27 | /// 往主题里发布一个数据 28 | fn publish(self, name: String, value: Arc); 29 | } 30 | 31 | /// 用于主题发布和订阅的数据结构 32 | #[derive(Default)] 33 | pub struct Broadcaster { 34 | /// 所有的主题列表 35 | topics: DashMap>, 36 | /// 所有的订阅列表 37 | subscriptions: DashMap>>, 38 | } 39 | 40 | impl Topic for Arc { 41 | #[instrument(name = "topic_subscribe", skip_all)] 42 | fn subscribe(self, name: String) -> mpsc::Receiver> { 43 | let id = { 44 | let entry = self.topics.entry(name).or_default(); 45 | let id = get_next_subscription_id(); 46 | entry.value().insert(id); 47 | id 48 | }; 49 | 50 | // 生成一个 mpsc channel 51 | let (tx, rx) = mpsc::channel(BROADCAST_CAPACITY); 52 | 53 | let v: Value = (id as i64).into(); 54 | 55 | // 立刻发送 subscription id 到 rx 56 | let tx1 = tx.clone(); 57 | tokio::spawn(async move { 58 | if let Err(e) = tx1.send(Arc::new(v.into())).await { 59 | // TODO: 这个很小概率发生,但目前我们没有善后 60 | warn!("Failed to send subscription id: {}. 
Error: {:?}", id, e); 61 | } 62 | }); 63 | 64 | // 把 tx 存入 subscription table 65 | self.subscriptions.insert(id, tx); 66 | debug!("Subscription {} is added", id); 67 | 68 | // 返回 rx 给网络处理的上下文 69 | rx 70 | } 71 | 72 | #[instrument(name = "topic_unsubscribe", skip_all)] 73 | fn unsubscribe(self, name: String, id: u32) -> Result { 74 | match self.remove_subscription(name, id) { 75 | Some(id) => Ok(id), 76 | None => Err(KvError::NotFound(format!("subscription {}", id))), 77 | } 78 | } 79 | 80 | #[instrument(name = "topic_publish", skip_all)] 81 | fn publish(self, name: String, value: Arc) { 82 | tokio::spawn(async move { 83 | let mut ids = vec![]; 84 | if let Some(topic) = self.topics.get(&name) { 85 | // 复制整个 topic 下所有的 subscription id 86 | // 这里我们每个 id 是 u32,如果一个 topic 下有 10k 订阅,复制的成本 87 | // 也就是 40k 堆内存(外加一些控制结构),所以效率不算差 88 | // 这也是为什么我们用 NEXT_ID 来控制 subscription id 的生成 89 | 90 | let subscriptions = topic.value().clone(); 91 | // 尽快释放锁 92 | drop(topic); 93 | 94 | // 循环发送 95 | for id in subscriptions.into_iter() { 96 | if let Some(tx) = self.subscriptions.get(&id) { 97 | if let Err(e) = tx.send(value.clone()).await { 98 | warn!("Publish to {} failed! 
error: {:?}", id, e); 99 | // client 中断连接 100 | ids.push(id); 101 | } 102 | } 103 | } 104 | } 105 | 106 | for id in ids { 107 | self.remove_subscription(name.clone(), id); 108 | } 109 | }); 110 | } 111 | } 112 | 113 | impl Broadcaster { 114 | pub fn remove_subscription(&self, name: String, id: u32) -> Option { 115 | if let Some(v) = self.topics.get_mut(&name) { 116 | // 在 topics 表里找到 topic 的 subscription id,删除 117 | v.remove(&id); 118 | 119 | // 如果这个 topic 为空,则也删除 topic 120 | if v.is_empty() { 121 | info!("Topic: {:?} is deleted", &name); 122 | drop(v); 123 | self.topics.remove(&name); 124 | } 125 | } 126 | 127 | debug!("Subscription {} is removed!", id); 128 | // 在 subscription 表中同样删除 129 | self.subscriptions.remove(&id).map(|(id, _)| id) 130 | } 131 | } 132 | 133 | #[cfg(test)] 134 | mod tests { 135 | use std::convert::TryInto; 136 | 137 | use tokio::sync::mpsc::Receiver; 138 | 139 | use crate::assert_res_ok; 140 | 141 | use super::*; 142 | 143 | #[tokio::test] 144 | async fn pub_sub_should_work() { 145 | let b = Arc::new(Broadcaster::default()); 146 | let lobby = "lobby".to_string(); 147 | 148 | // subscribe 149 | let mut stream1 = b.clone().subscribe(lobby.clone()); 150 | let mut stream2 = b.clone().subscribe(lobby.clone()); 151 | 152 | // publish 153 | let v: Value = "hello".into(); 154 | b.clone().publish(lobby.clone(), Arc::new(v.clone().into())); 155 | 156 | // subscribers 应该能收到 publish 的数据 157 | let id1 = get_id(&mut stream1).await; 158 | let id2 = get_id(&mut stream2).await; 159 | 160 | assert!(id1 != id2); 161 | 162 | let res1 = stream1.recv().await.unwrap(); 163 | let res2 = stream2.recv().await.unwrap(); 164 | 165 | assert_eq!(res1, res2); 166 | assert_res_ok(&res1, &[v.clone()], &[]); 167 | 168 | // 如果 subscriber 取消订阅,则收不到新数据 169 | let result = b.clone().unsubscribe(lobby.clone(), id1 as _).unwrap(); 170 | assert_eq!(result, id1 as _); 171 | 172 | // publish 173 | let v: Value = "world".into(); 174 | b.clone().publish(lobby.clone(), 
Arc::new(v.clone().into())); 175 | 176 | assert!(stream1.recv().await.is_none()); 177 | let res2 = stream2.recv().await.unwrap(); 178 | assert_res_ok(&res2, &[v.clone()], &[]); 179 | } 180 | 181 | pub async fn get_id(res: &mut Receiver>) -> u32 { 182 | let id: i64 = res.recv().await.unwrap().as_ref().try_into().unwrap(); 183 | id as u32 184 | } 185 | } 186 | -------------------------------------------------------------------------------- /src/service/topic_service.rs: -------------------------------------------------------------------------------- 1 | use futures::{stream, Stream}; 2 | use std::{pin::Pin, sync::Arc}; 3 | use tokio_stream::wrappers::ReceiverStream; 4 | 5 | use crate::{CommandResponse, Publish, Subscribe, Topic, Unsubscribe}; 6 | 7 | pub type StreamingResponse = Pin> + Send>>; 8 | 9 | pub trait TopicService { 10 | /// 处理 Command,返回 Response 11 | fn execute(self, topic: impl Topic) -> StreamingResponse; 12 | } 13 | 14 | impl TopicService for Subscribe { 15 | fn execute(self, topic: impl Topic) -> StreamingResponse { 16 | let rx = topic.subscribe(self.topic); 17 | Box::pin(ReceiverStream::new(rx)) 18 | } 19 | } 20 | 21 | impl TopicService for Unsubscribe { 22 | fn execute(self, topic: impl Topic) -> StreamingResponse { 23 | let res = match topic.unsubscribe(self.topic, self.id) { 24 | Ok(_) => CommandResponse::ok(), 25 | Err(e) => e.into(), 26 | }; 27 | Box::pin(stream::once(async { Arc::new(res) })) 28 | } 29 | } 30 | 31 | impl TopicService for Publish { 32 | fn execute(self, topic: impl Topic) -> StreamingResponse { 33 | topic.publish(self.topic, Arc::new(self.data.into())); 34 | Box::pin(stream::once(async { Arc::new(CommandResponse::ok()) })) 35 | } 36 | } 37 | 38 | #[cfg(test)] 39 | mod tests { 40 | use super::*; 41 | use crate::{assert_res_error, assert_res_ok, dispatch_stream, Broadcaster, CommandRequest}; 42 | use futures::StreamExt; 43 | use std::{convert::TryInto, time::Duration}; 44 | use tokio::time; 45 | 46 | #[tokio::test] 47 | async fn 
dispatch_publish_should_work() { 48 | let topic = Arc::new(Broadcaster::default()); 49 | let cmd = CommandRequest::new_publish("lobby", vec!["hello".into()]); 50 | let mut res = dispatch_stream(cmd, topic); 51 | let data = res.next().await.unwrap(); 52 | assert_res_ok(&data, &[], &[]); 53 | } 54 | 55 | #[tokio::test] 56 | async fn dispatch_subscribe_should_work() { 57 | let topic = Arc::new(Broadcaster::default()); 58 | let cmd = CommandRequest::new_subscribe("lobby"); 59 | let mut res = dispatch_stream(cmd, topic); 60 | let id = get_id(&mut res).await; 61 | assert!(id > 0); 62 | } 63 | 64 | #[tokio::test] 65 | async fn dispatch_subscribe_abnormal_quit_should_be_removed_on_next_publish() { 66 | let topic = Arc::new(Broadcaster::default()); 67 | let id = { 68 | let cmd = CommandRequest::new_subscribe("lobby"); 69 | let mut res = dispatch_stream(cmd, topic.clone()); 70 | let id = get_id(&mut res).await; 71 | drop(res); 72 | id as u32 73 | }; 74 | 75 | // publish 时,这个 subscription 已经失效,所以会被删除 76 | let cmd = CommandRequest::new_publish("lobby", vec!["hello".into()]); 77 | dispatch_stream(cmd, topic.clone()); 78 | time::sleep(Duration::from_millis(10)).await; 79 | 80 | // 如果再尝试删除,应该返回 KvError 81 | let result = topic.unsubscribe("lobby".into(), id); 82 | assert!(result.is_err()); 83 | } 84 | 85 | #[tokio::test] 86 | async fn dispatch_unsubscribe_should_work() { 87 | let topic = Arc::new(Broadcaster::default()); 88 | let cmd = CommandRequest::new_subscribe("lobby"); 89 | let mut res = dispatch_stream(cmd, topic.clone()); 90 | let id = get_id(&mut res).await; 91 | 92 | let cmd = CommandRequest::new_unsubscribe("lobby", id as _); 93 | let mut res = dispatch_stream(cmd, topic); 94 | let data = res.next().await.unwrap(); 95 | 96 | assert_res_ok(&data, &[], &[]); 97 | } 98 | 99 | #[tokio::test] 100 | async fn dispatch_unsubscribe_random_id_should_error() { 101 | let topic = Arc::new(Broadcaster::default()); 102 | 103 | let cmd = CommandRequest::new_unsubscribe("lobby", 9527); 
104 | let mut res = dispatch_stream(cmd, topic); 105 | let data = res.next().await.unwrap(); 106 | 107 | assert_res_error(&data, 404, "Not found: subscription 9527"); 108 | } 109 | 110 | pub async fn get_id(res: &mut StreamingResponse) -> u32 { 111 | let id: i64 = res.next().await.unwrap().as_ref().try_into().unwrap(); 112 | id as u32 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /src/storage/memory.rs: -------------------------------------------------------------------------------- 1 | use crate::{KvError, Kvpair, Storage, StorageIter, Value}; 2 | use dashmap::{mapref::one::Ref, DashMap}; 3 | 4 | /// 使用 DashMap 构建的 MemTable,实现了 Storage trait 5 | #[derive(Clone, Debug, Default)] 6 | pub struct MemTable { 7 | tables: DashMap>, 8 | } 9 | 10 | impl MemTable { 11 | /// 创建一个缺省的 MemTable 12 | pub fn new() -> Self { 13 | Self::default() 14 | } 15 | 16 | /// 如果名为 name 的 hash table 不存在,则创建,否则返回 17 | fn get_or_create_table(&self, name: &str) -> Ref> { 18 | match self.tables.get(name) { 19 | Some(table) => table, 20 | None => { 21 | let entry = self.tables.entry(name.into()).or_default(); 22 | entry.downgrade() 23 | } 24 | } 25 | } 26 | } 27 | 28 | impl Storage for MemTable { 29 | fn get(&self, table: &str, key: &str) -> Result, KvError> { 30 | let table = self.get_or_create_table(table); 31 | Ok(table.get(key).map(|v| v.value().clone())) 32 | } 33 | 34 | fn set(&self, table: &str, key: String, value: Value) -> Result, KvError> { 35 | let table = self.get_or_create_table(table); 36 | Ok(table.insert(key, value)) 37 | } 38 | 39 | fn contains(&self, table: &str, key: &str) -> Result { 40 | let table = self.get_or_create_table(table); 41 | Ok(table.contains_key(key)) 42 | } 43 | 44 | fn del(&self, table: &str, key: &str) -> Result, KvError> { 45 | let table = self.get_or_create_table(table); 46 | Ok(table.remove(key).map(|(_k, v)| v)) 47 | } 48 | 49 | fn get_all(&self, table: &str) -> Result, KvError> { 50 | let table = 
self.get_or_create_table(table); 51 | Ok(table 52 | .iter() 53 | .map(|v| Kvpair::new(v.key(), v.value().clone())) 54 | .collect()) 55 | } 56 | 57 | fn get_iter(&self, table: &str) -> Result>, KvError> { 58 | // 使用 clone() 来获取 table 的 snapshot 59 | let table = self.get_or_create_table(table).clone(); 60 | let iter = StorageIter::new(table.into_iter()); 61 | Ok(Box::new(iter)) 62 | } 63 | } 64 | 65 | impl From<(String, Value)> for Kvpair { 66 | fn from(data: (String, Value)) -> Self { 67 | Kvpair::new(data.0, data.1) 68 | } 69 | } 70 | 71 | #[cfg(test)] 72 | mod tests { 73 | use super::*; 74 | 75 | #[test] 76 | fn get_or_create_table_should_work() { 77 | let store = MemTable::new(); 78 | assert!(!store.tables.contains_key("t1")); 79 | store.get_or_create_table("t1"); 80 | assert!(store.tables.contains_key("t1")); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/storage/mod.rs: -------------------------------------------------------------------------------- 1 | mod memory; 2 | mod sleddb; 3 | 4 | pub use memory::MemTable; 5 | pub use sleddb::SledDb; 6 | 7 | use crate::{KvError, Kvpair, Value}; 8 | 9 | /// 对存储的抽象,我们不关心数据存在哪儿,但需要定义外界如何和存储打交道 10 | pub trait Storage: Send + Sync + 'static { 11 | /// 从一个 HashTable 里获取一个 key 的 value 12 | fn get(&self, table: &str, key: &str) -> Result, KvError>; 13 | /// 从一个 HashTable 里设置一个 key 的 value,返回旧的 value 14 | fn set(&self, table: &str, key: String, value: Value) -> Result, KvError>; 15 | /// 查看 HashTable 中是否有 key 16 | fn contains(&self, table: &str, key: &str) -> Result; 17 | /// 从 HashTable 中删除一个 key 18 | fn del(&self, table: &str, key: &str) -> Result, KvError>; 19 | /// 遍历 HashTable,返回所有 kv pair(这个接口不好) 20 | fn get_all(&self, table: &str) -> Result, KvError>; 21 | /// 遍历 HashTable,返回 kv pair 的 Iterator 22 | fn get_iter(&self, table: &str) -> Result>, KvError>; 23 | } 24 | 25 | /// 提供 Storage iterator,这样 trait 的实现者只需要 26 | /// 把它们的 iterator 提供给 StorageIter,然后它们保证 27 | /// 
next() 传出的类型实现了 Into 即可 28 | pub struct StorageIter { 29 | data: T, 30 | } 31 | 32 | impl StorageIter { 33 | pub fn new(data: T) -> Self { 34 | Self { data } 35 | } 36 | } 37 | 38 | impl Iterator for StorageIter 39 | where 40 | T: Iterator, 41 | T::Item: Into, 42 | { 43 | type Item = Kvpair; 44 | 45 | fn next(&mut self) -> Option { 46 | self.data.next().map(|v| v.into()) 47 | } 48 | } 49 | 50 | #[cfg(test)] 51 | mod tests { 52 | use tempfile::tempdir; 53 | 54 | use super::*; 55 | 56 | #[test] 57 | fn memtable_basic_interface_should_work() { 58 | let store = MemTable::new(); 59 | test_basi_interface(store); 60 | } 61 | 62 | #[test] 63 | fn memtable_get_all_should_work() { 64 | let store = MemTable::new(); 65 | test_get_all(store); 66 | } 67 | 68 | #[test] 69 | fn memtable_iter_should_work() { 70 | let store = MemTable::new(); 71 | test_get_iter(store); 72 | } 73 | 74 | #[test] 75 | fn sleddb_basic_interface_should_work() { 76 | let dir = tempdir().unwrap(); 77 | let store = SledDb::new(dir); 78 | test_basi_interface(store); 79 | } 80 | 81 | #[test] 82 | fn sleddb_get_all_should_work() { 83 | let dir = tempdir().unwrap(); 84 | let store = SledDb::new(dir); 85 | test_get_all(store); 86 | } 87 | 88 | #[test] 89 | fn sleddb_iter_should_work() { 90 | let dir = tempdir().unwrap(); 91 | let store = SledDb::new(dir); 92 | test_get_iter(store); 93 | } 94 | 95 | fn test_basi_interface(store: impl Storage) { 96 | // 第一次 set 会创建 table,插入 key 并返回 None(之前没值) 97 | let v = store.set("t1", "hello".into(), "world".into()); 98 | assert!(v.unwrap().is_none()); 99 | // 再次 set 同样的 key 会更新,并返回之前的值 100 | let v1 = store.set("t1", "hello".into(), "world1".into()); 101 | assert_eq!(v1.unwrap(), Some("world".into())); 102 | 103 | // get 存在的 key 会得到最新的值 104 | let v = store.get("t1", "hello"); 105 | assert_eq!(v.unwrap(), Some("world1".into())); 106 | 107 | // get 不存在的 key 或者 table 会得到 None 108 | assert_eq!(None, store.get("t1", "hello1").unwrap()); 109 | assert!(store.get("t2", 
"hello1").unwrap().is_none()); 110 | 111 | // contains 纯在的 key 返回 true,否则 false 112 | assert!(store.contains("t1", "hello").unwrap()); 113 | assert!(!store.contains("t1", "hello1").unwrap()); 114 | assert!(!store.contains("t2", "hello").unwrap()); 115 | 116 | // del 存在的 key 返回之前的值 117 | let v = store.del("t1", "hello"); 118 | assert_eq!(v.unwrap(), Some("world1".into())); 119 | 120 | // del 不存在的 key 或 table 返回 None 121 | assert_eq!(None, store.del("t1", "hello1").unwrap()); 122 | assert_eq!(None, store.del("t2", "hello").unwrap()); 123 | } 124 | 125 | fn test_get_all(store: impl Storage) { 126 | store.set("t2", "k1".into(), "v1".into()).unwrap(); 127 | store.set("t2", "k2".into(), "v2".into()).unwrap(); 128 | let mut data = store.get_all("t2").unwrap(); 129 | data.sort_by(|a, b| a.partial_cmp(b).unwrap()); 130 | assert_eq!( 131 | data, 132 | vec![ 133 | Kvpair::new("k1", "v1".into()), 134 | Kvpair::new("k2", "v2".into()) 135 | ] 136 | ) 137 | } 138 | 139 | fn test_get_iter(store: impl Storage) { 140 | store.set("t2", "k1".into(), "v1".into()).unwrap(); 141 | store.set("t2", "k2".into(), "v2".into()).unwrap(); 142 | let mut data: Vec<_> = store.get_iter("t2").unwrap().collect(); 143 | data.sort_by(|a, b| a.partial_cmp(b).unwrap()); 144 | assert_eq!( 145 | data, 146 | vec![ 147 | Kvpair::new("k1", "v1".into()), 148 | Kvpair::new("k2", "v2".into()) 149 | ] 150 | ) 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /src/storage/sleddb.rs: -------------------------------------------------------------------------------- 1 | use sled::{Db, IVec}; 2 | use std::{convert::TryInto, path::Path, str}; 3 | 4 | use crate::{KvError, Kvpair, Storage, StorageIter, Value}; 5 | 6 | #[derive(Debug)] 7 | pub struct SledDb(Db); 8 | 9 | impl SledDb { 10 | pub fn new(path: impl AsRef) -> Self { 11 | Self(sled::open(path).unwrap()) 12 | } 13 | 14 | // 在 sleddb 里,因为它可以 scan_prefix,我们用 prefix 15 | // 来模拟一个 table。当然,还可以用其它方案。 16 | fn 
get_full_key(table: &str, key: &str) -> String { 17 | format!("{}:{}", table, key) 18 | } 19 | 20 | // 遍历 table 的 key 时,我们直接把 prefix: 当成 table 21 | fn get_table_prefix(table: &str) -> String { 22 | format!("{}:", table) 23 | } 24 | } 25 | 26 | /// 把 Option> flip 成 Result, E> 27 | /// 从这个函数里,你可以看到函数式编程的优雅 28 | fn flip(x: Option>) -> Result, E> { 29 | x.map_or(Ok(None), |v| v.map(Some)) 30 | } 31 | 32 | impl Storage for SledDb { 33 | fn get(&self, table: &str, key: &str) -> Result, KvError> { 34 | let name = SledDb::get_full_key(table, key); 35 | let result = self.0.get(name.as_bytes())?.map(|v| v.as_ref().try_into()); 36 | flip(result) 37 | } 38 | 39 | fn set(&self, table: &str, key: String, value: Value) -> Result, KvError> { 40 | let name = SledDb::get_full_key(table, &key); 41 | let data: Vec = value.try_into()?; 42 | 43 | let result = self.0.insert(name, data)?.map(|v| v.as_ref().try_into()); 44 | flip(result) 45 | } 46 | 47 | fn contains(&self, table: &str, key: &str) -> Result { 48 | let name = SledDb::get_full_key(table, key); 49 | 50 | Ok(self.0.contains_key(name)?) 
51 | } 52 | 53 | fn del(&self, table: &str, key: &str) -> Result, KvError> { 54 | let name = SledDb::get_full_key(table, key); 55 | 56 | let result = self.0.remove(name)?.map(|v| v.as_ref().try_into()); 57 | flip(result) 58 | } 59 | 60 | fn get_all(&self, table: &str) -> Result, KvError> { 61 | let prefix = SledDb::get_table_prefix(table); 62 | let result = self.0.scan_prefix(prefix).map(|v| v.into()).collect(); 63 | 64 | Ok(result) 65 | } 66 | 67 | fn get_iter(&self, table: &str) -> Result>, KvError> { 68 | let prefix = SledDb::get_table_prefix(table); 69 | let iter = StorageIter::new(self.0.scan_prefix(prefix)); 70 | Ok(Box::new(iter)) 71 | } 72 | } 73 | 74 | impl From> for Kvpair { 75 | fn from(v: Result<(IVec, IVec), sled::Error>) -> Self { 76 | match v { 77 | Ok((k, v)) => match v.as_ref().try_into() { 78 | Ok(v) => Kvpair::new(ivec_to_key(k.as_ref()), v), 79 | Err(_) => Kvpair::default(), 80 | }, 81 | _ => Kvpair::default(), 82 | } 83 | } 84 | } 85 | 86 | fn ivec_to_key(ivec: &[u8]) -> &str { 87 | let s = str::from_utf8(ivec).unwrap(); 88 | let mut iter = s.split(':'); 89 | iter.next(); 90 | iter.next().unwrap() 91 | } 92 | -------------------------------------------------------------------------------- /tests/server.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use simple_kv::{ 3 | start_server_with_config, start_yamux_client_with_config, AppStream, ClientConfig, 4 | CommandRequest, ServerConfig, StorageConfig, 5 | }; 6 | use std::time::Duration; 7 | use tokio::time; 8 | 9 | #[tokio::test] 10 | async fn yamux_server_client_full_tests() -> Result<()> { 11 | let addr = "127.0.0.1:10086"; 12 | 13 | let mut config: ServerConfig = toml::from_str(include_str!("../fixtures/server.conf"))?; 14 | config.general.addr = addr.into(); 15 | config.storage = StorageConfig::MemTable; 16 | 17 | // 启动服务器 18 | tokio::spawn(async move { 19 | start_server_with_config(&config).await.unwrap(); 20 | }); 21 | 22 | 
time::sleep(Duration::from_millis(10)).await; 23 | let mut config: ClientConfig = toml::from_str(include_str!("../fixtures/client.conf"))?; 24 | config.general.addr = addr.into(); 25 | 26 | let mut ctrl = start_yamux_client_with_config(&config).await.unwrap(); 27 | let mut stream = ctrl.open_stream().await?; 28 | 29 | // 生成一个 HSET 命令 30 | let cmd = CommandRequest::new_hset("table1", "hello", "world".to_string().into()); 31 | stream.execute_unary(&cmd).await?; 32 | 33 | // 生成一个 HGET 命令 34 | let cmd = CommandRequest::new_hget("table1", "hello"); 35 | let data = stream.execute_unary(&cmd).await?; 36 | 37 | assert_eq!(data.status, 200); 38 | assert_eq!(data.values, &["world".into()]); 39 | 40 | Ok(()) 41 | } 42 | -------------------------------------------------------------------------------- /tools/gen_cert.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use certify::{generate_ca, generate_cert, load_ca, CertType, CA}; 3 | use tokio::fs; 4 | 5 | struct CertPem { 6 | cert_type: CertType, 7 | cert: String, 8 | key: String, 9 | } 10 | 11 | #[tokio::main] 12 | async fn main() -> Result<()> { 13 | let pem = create_ca()?; 14 | gen_files(&pem).await?; 15 | let ca = load_ca(&pem.cert, &pem.key)?; 16 | let pem = create_cert(&ca, &["kvserver.acme.inc"], "Acme KV server", false)?; 17 | gen_files(&pem).await?; 18 | let pem = create_cert(&ca, &[], "awesome-device-id", true)?; 19 | gen_files(&pem).await?; 20 | Ok(()) 21 | } 22 | 23 | fn create_ca() -> Result { 24 | let (cert, key) = generate_ca( 25 | &["acme.inc"], 26 | "CN", 27 | "Acme Inc.", 28 | "Acme CA", 29 | None, 30 | Some(10 * 365), 31 | )?; 32 | Ok(CertPem { 33 | cert_type: CertType::CA, 34 | cert, 35 | key, 36 | }) 37 | } 38 | 39 | fn create_cert(ca: &CA, domains: &[&str], cn: &str, is_client: bool) -> Result { 40 | let (days, cert_type) = if is_client { 41 | (Some(365), CertType::Client) 42 | } else { 43 | (Some(5 * 365), CertType::Server) 44 | }; 45 | let 
(cert, key) = generate_cert(ca, domains, "CN", "Acme Inc.", cn, None, is_client, days)?; 46 | 47 | Ok(CertPem { 48 | cert_type, 49 | cert, 50 | key, 51 | }) 52 | } 53 | 54 | async fn gen_files(pem: &CertPem) -> Result<()> { 55 | let name = match pem.cert_type { 56 | CertType::Client => "client", 57 | CertType::Server => "server", 58 | CertType::CA => "ca", 59 | }; 60 | fs::write(format!("fixtures/{}.cert", name), pem.cert.as_bytes()).await?; 61 | fs::write(format!("fixtures/{}.key", name), pem.key.as_bytes()).await?; 62 | Ok(()) 63 | } 64 | -------------------------------------------------------------------------------- /tools/gen_config.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use simple_kv::{ 3 | ClientConfig, ClientTlsConfig, GeneralConfig, LogConfig, NetworkType, RotationConfig, 4 | ServerConfig, ServerTlsConfig, StorageConfig, 5 | }; 6 | use std::fs; 7 | 8 | fn main() -> Result<()> { 9 | const CA_CERT: &str = include_str!("../fixtures/ca.cert"); 10 | const SERVER_CERT: &str = include_str!("../fixtures/server.cert"); 11 | const SERVER_KEY: &str = include_str!("../fixtures/server.key"); 12 | 13 | let general_config = GeneralConfig { 14 | addr: "127.0.0.1:9527".into(), 15 | network: NetworkType::Tcp, 16 | }; 17 | let server_config = ServerConfig { 18 | storage: StorageConfig::SledDb("/tmp/kv_server".into()), 19 | general: general_config.clone(), 20 | tls: ServerTlsConfig { 21 | cert: SERVER_CERT.into(), 22 | key: SERVER_KEY.into(), 23 | ca: None, 24 | }, 25 | log: LogConfig { 26 | enable_jaeger: false, 27 | enable_log_file: false, 28 | log_level: "info".to_string(), 29 | path: "/tmp/kv-log".into(), 30 | rotation: RotationConfig::Daily, 31 | }, 32 | }; 33 | 34 | fs::write( 35 | "fixtures/server.conf", 36 | toml::to_string_pretty(&server_config)?, 37 | )?; 38 | 39 | let client_config = ClientConfig { 40 | general: general_config, 41 | 42 | tls: ClientTlsConfig { 43 | identity: None, 44 | ca: 
Some(CA_CERT.into()), 45 | domain: "kvserver.acme.inc".into(), 46 | }, 47 | }; 48 | 49 | fs::write( 50 | "fixtures/client.conf", 51 | toml::to_string_pretty(&client_config)?, 52 | )?; 53 | 54 | Ok(()) 55 | } 56 | --------------------------------------------------------------------------------