├── .config └── nextest.toml ├── .github └── workflows │ └── rust.yml ├── .gitignore ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── r2dma ├── Cargo.lock ├── Cargo.toml ├── README.md ├── build.rs ├── examples │ └── r2dma.rs └── src │ ├── buf │ ├── aligned_buffer.rs │ ├── buffer_pool.rs │ ├── mod.rs │ └── rdma_buffer.rs │ ├── core │ ├── comp_queues.rs │ ├── config.rs │ ├── devices.rs │ ├── event_loop.rs │ ├── mod.rs │ └── queue_pair.rs │ ├── error.rs │ ├── lib.rs │ └── verbs.rs ├── r2pc-demo ├── Cargo.toml └── src │ ├── bin │ ├── client.rs │ └── server.rs │ └── lib.rs ├── r2pc-macro ├── Cargo.toml └── src │ └── lib.rs ├── r2pc ├── Cargo.toml ├── examples │ └── r2pc_info.rs ├── src │ ├── client.rs │ ├── connection_pool.rs │ ├── constants.rs │ ├── context.rs │ ├── core_service │ │ ├── info_service.rs │ │ └── mod.rs │ ├── error.rs │ ├── lib.rs │ ├── meta.rs │ ├── server.rs │ └── transport.rs └── tests │ ├── test_concurrent.rs │ └── test_service.rs └── rust-toolchain.toml /.config/nextest.toml: -------------------------------------------------------------------------------- 1 | [profile.default.junit] 2 | path = "junit.xml" 3 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: [ "main", "dev" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Setup Soft-RoCE 20 | run: | 21 | KERNEL_VERSION=$(uname -r | cut -d '-' -f 1) 22 | KERNEL_NAME="linux-${KERNEL_VERSION%'.0'}" 23 | DOWNLOAD_LINK="https://cdn.kernel.org/pub/linux/kernel/v${KERNEL_VERSION%%.*}.x/${KERNEL_NAME}.tar.xz" 24 | ETHERNET_CARD=$(ip link | awk -F ": " '$0 !~ "lo|vir|wl|^[^0-9]"{print $2;getline}' | head -1) 25 | echo "kernel version is 
${KERNEL_VERSION}, download link is ${DOWNLOAD_LINK}, ethernet card is ${ETHERNET_CARD}" 26 | wget -q $DOWNLOAD_LINK -O /tmp/$KERNEL_NAME.tar.xz 27 | tar xf /tmp/$KERNEL_NAME.tar.xz --directory=/tmp 28 | RXE_PATH="/tmp/$KERNEL_NAME/drivers/infiniband/sw/rxe" 29 | sed 's/$(CONFIG_RDMA_RXE)/m/g' $RXE_PATH/Makefile > $RXE_PATH/Kbuild 30 | make -C /lib/modules/$(uname -r)/build M=$RXE_PATH modules -j 31 | sudo modprobe ib_core 32 | sudo modprobe rdma_ucm 33 | sudo insmod $RXE_PATH/rdma_rxe.ko 34 | sudo rdma link add rxe_0 type rxe netdev $ETHERNET_CARD 35 | rdma link 36 | 37 | - name: Run tests 38 | run: | 39 | sudo prlimit --pid $$ -l=unlimited && ulimit -a 40 | sudo apt install -y pkg-config libibverbs-dev ibverbs-utils 41 | ibv_devinfo -d rxe_0 -v 42 | cargo install cargo-llvm-cov cargo-nextest 43 | cargo llvm-cov nextest 44 | cargo llvm-cov report --cobertura --output-path target/llvm-cov-target/cobertura.xml 45 | 46 | - name: Upload coverage reports to Codecov 47 | uses: codecov/codecov-action@v4 48 | with: 49 | token: ${{ secrets.CODECOV_TOKEN }} 50 | slug: SF-Zhou/r2dma 51 | files: target/llvm-cov-target/cobertura.xml 52 | 53 | - name: Upload test results to Codecov 54 | uses: codecov/test-results-action@v1 55 | with: 56 | token: ${{ secrets.CODECOV_TOKEN }} 57 | slug: SF-Zhou/r2dma 58 | files: target/nextest/default/junit.xml 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["r2dma", "r2pc", "r2pc-demo", "r2pc-macro"] 3 | resolver = "2" 4 | 5 | [workspace.package] 6 | authors = ["SF-Zhou "] 7 | edition = "2021" 8 | homepage = "https://github.com/SF-Zhou/r2dma" 9 | repository = 
"https://github.com/SF-Zhou/r2dma" 10 | description = "A Rust RDMA library" 11 | license = "MIT OR Apache-2.0" 12 | 13 | [workspace.dependencies] 14 | derse = "0" 15 | libc = "0" 16 | thiserror = "2" 17 | tracing = "0" 18 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 
175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 SF-Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # r2dma 2 | 3 | [![Rust](https://github.com/SF-Zhou/r2dma/actions/workflows/rust.yml/badge.svg)](https://github.com/SF-Zhou/r2dma/actions/workflows/rust.yml) 4 | [![codecov](https://codecov.io/gh/SF-Zhou/r2dma/graph/badge.svg?token=AB5ULDT77Z)](https://codecov.io/gh/SF-Zhou/r2dma) 5 | [![crates.io](https://img.shields.io/crates/v/r2dma.svg)](https://crates.io/crates/r2dma) 6 | [![stability-wip](https://img.shields.io/badge/stability-wip-lightgrey.svg)](https://github.com/mkenney/software-guides/blob/master/STABILITY-BADGES.md#work-in-progress) 7 | 8 | A Rust RDMA library. 9 | -------------------------------------------------------------------------------- /r2dma/Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 3 4 | 5 | [[package]] 6 | name = "aho-corasick" 7 | version = "1.1.3" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" 10 | dependencies = [ 11 | "memchr", 12 | ] 13 | 14 | [[package]] 15 | name = "bindgen" 16 | version = "0.70.1" 17 | source = "registry+https://github.com/rust-lang/crates.io-index" 18 | checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f" 19 | dependencies = [ 20 | "bitflags", 21 | "cexpr", 22 | "clang-sys", 23 | "itertools", 24 | "log", 25 | "prettyplease", 26 | "proc-macro2", 27 | "quote", 28 | "regex", 29 | "rustc-hash", 30 | "shlex", 31 | "syn", 32 | ] 33 | 34 | [[package]] 35 | name = "bitflags" 36 | version = "2.6.0" 37 | source = "registry+https://github.com/rust-lang/crates.io-index" 38 | checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" 39 | 40 | [[package]] 41 | name = "cexpr" 42 | version = "0.6.0" 43 | source = "registry+https://github.com/rust-lang/crates.io-index" 44 | checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" 45 | dependencies = [ 46 | "nom", 47 | ] 48 | 49 | [[package]] 50 | name = "cfg-if" 51 | version = "1.0.0" 52 | source = "registry+https://github.com/rust-lang/crates.io-index" 53 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 54 | 55 | [[package]] 56 | name = "clang-sys" 57 | version = "1.8.1" 58 | source = "registry+https://github.com/rust-lang/crates.io-index" 59 | checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" 60 | dependencies = [ 61 | "glob", 62 | "libc", 63 | "libloading", 64 | ] 65 | 66 | [[package]] 67 | name = "derse" 68 | version = "0.1.31" 69 | source = "registry+https://github.com/rust-lang/crates.io-index" 70 | checksum = "7aa1a95f27196eccf8e87e0a7be6ee79d90bafc1c392302d942011384ee2ee4a" 71 | dependencies = [ 72 | "derse-derive", 73 | "thiserror 
1.0.69", 74 | ] 75 | 76 | [[package]] 77 | name = "derse-derive" 78 | version = "0.1.15" 79 | source = "registry+https://github.com/rust-lang/crates.io-index" 80 | checksum = "effbaf2997637f96a3b52502d1b4d8ca7a7f4a55a5206ef6d4d71712fa2368be" 81 | dependencies = [ 82 | "proc-macro-crate", 83 | "proc-macro2", 84 | "quote", 85 | "syn", 86 | ] 87 | 88 | [[package]] 89 | name = "either" 90 | version = "1.13.0" 91 | source = "registry+https://github.com/rust-lang/crates.io-index" 92 | checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" 93 | 94 | [[package]] 95 | name = "equivalent" 96 | version = "1.0.1" 97 | source = "registry+https://github.com/rust-lang/crates.io-index" 98 | checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" 99 | 100 | [[package]] 101 | name = "glob" 102 | version = "0.3.1" 103 | source = "registry+https://github.com/rust-lang/crates.io-index" 104 | checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" 105 | 106 | [[package]] 107 | name = "hashbrown" 108 | version = "0.15.1" 109 | source = "registry+https://github.com/rust-lang/crates.io-index" 110 | checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" 111 | 112 | [[package]] 113 | name = "indexmap" 114 | version = "2.6.0" 115 | source = "registry+https://github.com/rust-lang/crates.io-index" 116 | checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" 117 | dependencies = [ 118 | "equivalent", 119 | "hashbrown", 120 | ] 121 | 122 | [[package]] 123 | name = "itertools" 124 | version = "0.13.0" 125 | source = "registry+https://github.com/rust-lang/crates.io-index" 126 | checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" 127 | dependencies = [ 128 | "either", 129 | ] 130 | 131 | [[package]] 132 | name = "libc" 133 | version = "0.2.164" 134 | source = "registry+https://github.com/rust-lang/crates.io-index" 135 | checksum = 
"433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" 136 | 137 | [[package]] 138 | name = "libloading" 139 | version = "0.8.5" 140 | source = "registry+https://github.com/rust-lang/crates.io-index" 141 | checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" 142 | dependencies = [ 143 | "cfg-if", 144 | "windows-targets", 145 | ] 146 | 147 | [[package]] 148 | name = "log" 149 | version = "0.4.22" 150 | source = "registry+https://github.com/rust-lang/crates.io-index" 151 | checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" 152 | 153 | [[package]] 154 | name = "memchr" 155 | version = "2.7.4" 156 | source = "registry+https://github.com/rust-lang/crates.io-index" 157 | checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" 158 | 159 | [[package]] 160 | name = "minimal-lexical" 161 | version = "0.2.1" 162 | source = "registry+https://github.com/rust-lang/crates.io-index" 163 | checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" 164 | 165 | [[package]] 166 | name = "nom" 167 | version = "7.1.3" 168 | source = "registry+https://github.com/rust-lang/crates.io-index" 169 | checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" 170 | dependencies = [ 171 | "memchr", 172 | "minimal-lexical", 173 | ] 174 | 175 | [[package]] 176 | name = "once_cell" 177 | version = "1.20.2" 178 | source = "registry+https://github.com/rust-lang/crates.io-index" 179 | checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" 180 | 181 | [[package]] 182 | name = "pin-project-lite" 183 | version = "0.2.15" 184 | source = "registry+https://github.com/rust-lang/crates.io-index" 185 | checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" 186 | 187 | [[package]] 188 | name = "pkg-config" 189 | version = "0.3.31" 190 | source = "registry+https://github.com/rust-lang/crates.io-index" 191 | checksum = 
"953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" 192 | 193 | [[package]] 194 | name = "prettyplease" 195 | version = "0.2.25" 196 | source = "registry+https://github.com/rust-lang/crates.io-index" 197 | checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" 198 | dependencies = [ 199 | "proc-macro2", 200 | "syn", 201 | ] 202 | 203 | [[package]] 204 | name = "proc-macro-crate" 205 | version = "3.2.0" 206 | source = "registry+https://github.com/rust-lang/crates.io-index" 207 | checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" 208 | dependencies = [ 209 | "toml_edit", 210 | ] 211 | 212 | [[package]] 213 | name = "proc-macro2" 214 | version = "1.0.92" 215 | source = "registry+https://github.com/rust-lang/crates.io-index" 216 | checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" 217 | dependencies = [ 218 | "unicode-ident", 219 | ] 220 | 221 | [[package]] 222 | name = "quote" 223 | version = "1.0.37" 224 | source = "registry+https://github.com/rust-lang/crates.io-index" 225 | checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" 226 | dependencies = [ 227 | "proc-macro2", 228 | ] 229 | 230 | [[package]] 231 | name = "r2dma" 232 | version = "0.1.1" 233 | dependencies = [ 234 | "bindgen", 235 | "derse", 236 | "libc", 237 | "pkg-config", 238 | "thiserror 2.0.3", 239 | "tracing", 240 | ] 241 | 242 | [[package]] 243 | name = "regex" 244 | version = "1.11.1" 245 | source = "registry+https://github.com/rust-lang/crates.io-index" 246 | checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" 247 | dependencies = [ 248 | "aho-corasick", 249 | "memchr", 250 | "regex-automata", 251 | "regex-syntax", 252 | ] 253 | 254 | [[package]] 255 | name = "regex-automata" 256 | version = "0.4.9" 257 | source = "registry+https://github.com/rust-lang/crates.io-index" 258 | checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" 259 | 
dependencies = [ 260 | "aho-corasick", 261 | "memchr", 262 | "regex-syntax", 263 | ] 264 | 265 | [[package]] 266 | name = "regex-syntax" 267 | version = "0.8.5" 268 | source = "registry+https://github.com/rust-lang/crates.io-index" 269 | checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" 270 | 271 | [[package]] 272 | name = "rustc-hash" 273 | version = "1.1.0" 274 | source = "registry+https://github.com/rust-lang/crates.io-index" 275 | checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" 276 | 277 | [[package]] 278 | name = "shlex" 279 | version = "1.3.0" 280 | source = "registry+https://github.com/rust-lang/crates.io-index" 281 | checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" 282 | 283 | [[package]] 284 | name = "syn" 285 | version = "2.0.89" 286 | source = "registry+https://github.com/rust-lang/crates.io-index" 287 | checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" 288 | dependencies = [ 289 | "proc-macro2", 290 | "quote", 291 | "unicode-ident", 292 | ] 293 | 294 | [[package]] 295 | name = "thiserror" 296 | version = "1.0.69" 297 | source = "registry+https://github.com/rust-lang/crates.io-index" 298 | checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" 299 | dependencies = [ 300 | "thiserror-impl 1.0.69", 301 | ] 302 | 303 | [[package]] 304 | name = "thiserror" 305 | version = "2.0.3" 306 | source = "registry+https://github.com/rust-lang/crates.io-index" 307 | checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" 308 | dependencies = [ 309 | "thiserror-impl 2.0.3", 310 | ] 311 | 312 | [[package]] 313 | name = "thiserror-impl" 314 | version = "1.0.69" 315 | source = "registry+https://github.com/rust-lang/crates.io-index" 316 | checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" 317 | dependencies = [ 318 | "proc-macro2", 319 | "quote", 320 | "syn", 321 | ] 322 | 323 | [[package]] 
324 | name = "thiserror-impl" 325 | version = "2.0.3" 326 | source = "registry+https://github.com/rust-lang/crates.io-index" 327 | checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" 328 | dependencies = [ 329 | "proc-macro2", 330 | "quote", 331 | "syn", 332 | ] 333 | 334 | [[package]] 335 | name = "toml_datetime" 336 | version = "0.6.8" 337 | source = "registry+https://github.com/rust-lang/crates.io-index" 338 | checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" 339 | 340 | [[package]] 341 | name = "toml_edit" 342 | version = "0.22.22" 343 | source = "registry+https://github.com/rust-lang/crates.io-index" 344 | checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" 345 | dependencies = [ 346 | "indexmap", 347 | "toml_datetime", 348 | "winnow", 349 | ] 350 | 351 | [[package]] 352 | name = "tracing" 353 | version = "0.1.40" 354 | source = "registry+https://github.com/rust-lang/crates.io-index" 355 | checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" 356 | dependencies = [ 357 | "pin-project-lite", 358 | "tracing-attributes", 359 | "tracing-core", 360 | ] 361 | 362 | [[package]] 363 | name = "tracing-attributes" 364 | version = "0.1.27" 365 | source = "registry+https://github.com/rust-lang/crates.io-index" 366 | checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" 367 | dependencies = [ 368 | "proc-macro2", 369 | "quote", 370 | "syn", 371 | ] 372 | 373 | [[package]] 374 | name = "tracing-core" 375 | version = "0.1.32" 376 | source = "registry+https://github.com/rust-lang/crates.io-index" 377 | checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" 378 | dependencies = [ 379 | "once_cell", 380 | ] 381 | 382 | [[package]] 383 | name = "unicode-ident" 384 | version = "1.0.14" 385 | source = "registry+https://github.com/rust-lang/crates.io-index" 386 | checksum = 
"adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" 387 | 388 | [[package]] 389 | name = "windows-targets" 390 | version = "0.52.6" 391 | source = "registry+https://github.com/rust-lang/crates.io-index" 392 | checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" 393 | dependencies = [ 394 | "windows_aarch64_gnullvm", 395 | "windows_aarch64_msvc", 396 | "windows_i686_gnu", 397 | "windows_i686_gnullvm", 398 | "windows_i686_msvc", 399 | "windows_x86_64_gnu", 400 | "windows_x86_64_gnullvm", 401 | "windows_x86_64_msvc", 402 | ] 403 | 404 | [[package]] 405 | name = "windows_aarch64_gnullvm" 406 | version = "0.52.6" 407 | source = "registry+https://github.com/rust-lang/crates.io-index" 408 | checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" 409 | 410 | [[package]] 411 | name = "windows_aarch64_msvc" 412 | version = "0.52.6" 413 | source = "registry+https://github.com/rust-lang/crates.io-index" 414 | checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" 415 | 416 | [[package]] 417 | name = "windows_i686_gnu" 418 | version = "0.52.6" 419 | source = "registry+https://github.com/rust-lang/crates.io-index" 420 | checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" 421 | 422 | [[package]] 423 | name = "windows_i686_gnullvm" 424 | version = "0.52.6" 425 | source = "registry+https://github.com/rust-lang/crates.io-index" 426 | checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" 427 | 428 | [[package]] 429 | name = "windows_i686_msvc" 430 | version = "0.52.6" 431 | source = "registry+https://github.com/rust-lang/crates.io-index" 432 | checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" 433 | 434 | [[package]] 435 | name = "windows_x86_64_gnu" 436 | version = "0.52.6" 437 | source = "registry+https://github.com/rust-lang/crates.io-index" 438 | checksum = 
"147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" 439 | 440 | [[package]] 441 | name = "windows_x86_64_gnullvm" 442 | version = "0.52.6" 443 | source = "registry+https://github.com/rust-lang/crates.io-index" 444 | checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" 445 | 446 | [[package]] 447 | name = "windows_x86_64_msvc" 448 | version = "0.52.6" 449 | source = "registry+https://github.com/rust-lang/crates.io-index" 450 | checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" 451 | 452 | [[package]] 453 | name = "winnow" 454 | version = "0.6.20" 455 | source = "registry+https://github.com/rust-lang/crates.io-index" 456 | checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" 457 | dependencies = [ 458 | "memchr", 459 | ] 460 | -------------------------------------------------------------------------------- /r2dma/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "r2dma" 3 | version = "0.1.7" 4 | authors.workspace = true 5 | edition.workspace = true 6 | homepage.workspace = true 7 | repository.workspace = true 8 | description.workspace = true 9 | license.workspace = true 10 | 11 | [dependencies] 12 | derse.workspace = true 13 | libc.workspace = true 14 | thiserror.workspace = true 15 | tracing.workspace = true 16 | 17 | [build-dependencies] 18 | bindgen = "0" 19 | pkg-config = "0" 20 | 21 | [dev-dependencies] 22 | clap = { version = "4", features = ["derive"] } 23 | tracing-subscriber = { version = "0", features = ["chrono"] } 24 | -------------------------------------------------------------------------------- /r2dma/README.md: -------------------------------------------------------------------------------- 1 | ../README.md -------------------------------------------------------------------------------- /r2dma/build.rs: -------------------------------------------------------------------------------- 1 | use 
std::collections::HashSet; 2 | use std::env; 3 | use std::path::PathBuf; 4 | 5 | fn main() { 6 | let lib = pkg_config::Config::new() 7 | .statik(false) 8 | .probe("libibverbs") 9 | .unwrap_or_else(|_| panic!("please install libibverbs-dev and pkg-config")); 10 | 11 | let mut include_paths = lib.include_paths.into_iter().collect::>(); 12 | include_paths.insert(PathBuf::from("/usr/include")); 13 | 14 | let builder = bindgen::Builder::default() 15 | /* Display, not Debug: `{:?}` wraps the path in literal quotes, which libclang would treat as part of the directory name */ .clang_args(include_paths.iter().map(|p| format!("-I{}", p.display()))) 16 | .header_contents("header.h", "#include ") 17 | .derive_copy(true) 18 | .derive_debug(true) 19 | .derive_default(true) 20 | .generate_comments(false) 21 | .prepend_enum_name(false) 22 | .formatter(bindgen::Formatter::Rustfmt) 23 | .size_t_is_usize(true) 24 | .translate_enum_integer_types(true) 25 | .layout_tests(false) 26 | .default_enum_style(bindgen::EnumVariation::Rust { 27 | non_exhaustive: false, 28 | }) 29 | .opaque_type("pthread_cond_t") 30 | .opaque_type("pthread_mutex_t") 31 | .allowlist_type("ibv_access_flags") 32 | .allowlist_type("ibv_comp_channel") 33 | .allowlist_type("ibv_context") 34 | .allowlist_type("ibv_cq") 35 | .allowlist_type("ibv_device") 36 | .allowlist_type("ibv_gid") 37 | .allowlist_type("ibv_mr") 38 | .allowlist_type("ibv_pd") 39 | .allowlist_type("ibv_port_attr") 40 | .allowlist_type("ibv_qp") 41 | .allowlist_type("ibv_qp_attr_mask") 42 | .allowlist_type("ibv_qp_init_attr") 43 | .allowlist_type("ibv_send_flags") 44 | .allowlist_type("ibv_wc") 45 | .allowlist_type("ibv_wc_flags") 46 | .allowlist_type("ibv_wc_status") 47 | .allowlist_function("ibv_ack_cq_events") 48 | .allowlist_function("ibv_alloc_pd") 49 | .allowlist_function("ibv_close_device") 50 | .allowlist_function("ibv_create_comp_channel") 51 | .allowlist_function("ibv_create_cq") 52 | .allowlist_function("ibv_create_qp") 53 | .allowlist_function("ibv_dealloc_pd") 54 | .allowlist_function("ibv_dereg_mr") 55 | .allowlist_function("ibv_destroy_comp_channel") 56 |
.allowlist_function("ibv_destroy_cq") 57 | .allowlist_function("ibv_destroy_qp") 58 | .allowlist_function("ibv_free_device_list") 59 | .allowlist_function("ibv_get_cq_event") 60 | .allowlist_function("ibv_get_device_guid") 61 | .allowlist_function("ibv_get_device_list") 62 | .allowlist_function("ibv_modify_qp") 63 | .allowlist_function("ibv_req_notify_cq") 64 | .allowlist_function("ibv_poll_cq") 65 | .allowlist_function("ibv_post_recv") 66 | .allowlist_function("ibv_post_send") 67 | .allowlist_function("ibv_query_device") 68 | .allowlist_function("ibv_query_gid") 69 | .allowlist_function("ibv_query_port") 70 | .allowlist_function("ibv_open_device") 71 | .allowlist_function("ibv_reg_mr") 72 | .bitfield_enum("ibv_access_flags") 73 | .bitfield_enum("ibv_send_flags") 74 | .bitfield_enum("ibv_wc_flags") 75 | .bitfield_enum("ibv_qp_attr_mask") 76 | .no_copy("ibv_context") 77 | .no_copy("ibv_cq") 78 | .no_copy("ibv_qp") 79 | .no_copy("ibv_srq") 80 | .no_debug("ibv_device"); 81 | 82 | builder 83 | .generate() 84 | .expect("Unable to generate bindings") 85 | .write_to_file(PathBuf::from(env::var("OUT_DIR").unwrap()).join("bindings.rs")) 86 | .expect("Couldn't write bindings!"); 87 | } 88 | -------------------------------------------------------------------------------- /r2dma/examples/r2dma.rs: -------------------------------------------------------------------------------- 1 | use clap::Parser; 2 | use r2dma::{Result, *}; 3 | 4 | #[derive(Parser, Debug)] 5 | #[command(version, about, long_about = None)] 6 | struct Args { 7 | /// device filter by name. 8 | #[arg(long)] 9 | pub device_filter: Vec, 10 | 11 | /// enable gid type filter (IB or RoCE v2). 12 | #[arg(long, default_value_t = false)] 13 | pub gid_type_filter: bool, 14 | 15 | /// RoCE v2 skip link local address. 16 | #[arg(long, default_value_t = false)] 17 | pub skip_link_local_addr: bool, 18 | 19 | /// enable port state filter. 
20 | #[arg(long, default_value_t = false)] 21 | pub skip_inactive_port: bool, 22 | 23 | /// enable verbose logging. 24 | #[arg(long, short, default_value_t = false)] 25 | pub verbose: bool, 26 | } 27 | 28 | fn main() -> Result<()> { 29 | let args = Args::parse(); 30 | tracing_subscriber::fmt() 31 | .with_max_level(if args.verbose { 32 | tracing::Level::DEBUG 33 | } else { 34 | tracing::Level::INFO 35 | }) 36 | .with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) 37 | .init(); 38 | 39 | let mut config = DeviceConfig::default(); 40 | config.device_filter.extend(args.device_filter); 41 | if args.gid_type_filter { 42 | config.gid_type_filter = [GidType::IB, GidType::RoCEv2].into(); 43 | } 44 | if args.skip_link_local_addr { 45 | config.roce_v2_skip_link_local_addr = true; 46 | } 47 | /* forward --skip-inactive-port to the device config; without this the flag has no effect */ config.skip_inactive_port = args.skip_inactive_port; 48 | let devices = Devices::open(&config)?; 49 | for device in &devices { 50 | println!("device: {:#?}", device.info()); 51 | } 52 | 53 | Ok(()) 54 | } 55 | -------------------------------------------------------------------------------- /r2dma/src/buf/aligned_buffer.rs: -------------------------------------------------------------------------------- 1 | use crate::{Error, Result}; 2 | use std::alloc::Layout; 3 | 4 | pub const ALIGN_SIZE: usize = 4096; 5 | 6 | pub struct AlignedBuffer(&'static mut [u8]); 7 | 8 | impl AlignedBuffer { 9 | pub fn new(size: usize) -> Result { 10 | assert_ne!(size, 0, "the buffer length cannot be zero!"); 11 | /* checked rounding: `next_multiple_of` wraps in release builds, which could yield a zero-sized or tiny layout */ let size = size.checked_next_multiple_of(ALIGN_SIZE).ok_or(Error::AllocMemoryFailed)?; 12 | unsafe { 13 | /* validated layout: sizes above isize::MAX are rejected instead of being UB */ let layout = Layout::from_size_align(size, ALIGN_SIZE).map_err(|_| Error::AllocMemoryFailed)?; 14 | let ptr = std::alloc::alloc(layout); 15 | if ptr.is_null() { 16 | Err(Error::AllocMemoryFailed) 17 | } else { 18 | Ok(Self(std::slice::from_raw_parts_mut(ptr, size))) 19 | } 20 | } 21 | } 22 | } 23 | 24 | impl Drop for AlignedBuffer { 25 | fn drop(&mut self) { 26 | unsafe { 27 | let layout = Layout::from_size_align_unchecked(self.0.len(), ALIGN_SIZE); 28 | std::alloc::dealloc(self.0.as_mut_ptr(), layout);
29 | } 30 | } 31 | } 32 | 33 | impl std::ops::Deref for AlignedBuffer { 34 | type Target = [u8]; 35 | 36 | #[inline(always)] 37 | fn deref(&self) -> &Self::Target { 38 | self.0 39 | } 40 | } 41 | 42 | impl std::ops::DerefMut for AlignedBuffer { 43 | #[inline(always)] 44 | fn deref_mut(&mut self) -> &mut Self::Target { 45 | self.0 46 | } 47 | } 48 | 49 | unsafe impl Send for AlignedBuffer {} 50 | unsafe impl Sync for AlignedBuffer {} 51 | -------------------------------------------------------------------------------- /r2dma/src/buf/buffer_pool.rs: -------------------------------------------------------------------------------- 1 | use crate::*; 2 | use std::{ 3 | ops::{Deref, DerefMut}, 4 | sync::{Arc, Mutex}, 5 | }; 6 | 7 | pub struct BufferPool { 8 | buffer: RegisteredBuffer, 9 | block_size: usize, 10 | free_list: Mutex>, 11 | } 12 | 13 | pub struct Buffer { 14 | pool: Arc, 15 | idx: usize, 16 | } 17 | 18 | impl Drop for Buffer { 19 | fn drop(&mut self) { 20 | self.pool.deallocate(self.idx); 21 | } 22 | } 23 | 24 | impl Deref for Buffer { 25 | type Target = [u8]; 26 | 27 | fn deref(&self) -> &Self::Target { 28 | let start = self.idx * self.pool.block_size; 29 | &self.pool.buffer[start..start + self.pool.block_size] 30 | } 31 | } 32 | 33 | impl DerefMut for Buffer { 34 | fn deref_mut(&mut self) -> &mut Self::Target { 35 | let buf: &[u8] = self.deref(); 36 | unsafe { std::slice::from_raw_parts_mut(buf.as_ptr() as *mut u8, buf.len()) } 37 | } 38 | } 39 | 40 | impl Buffer { 41 | pub fn lkey(&self, device: &Device) -> u32 { 42 | self.pool.buffer.lkey(device.index()) 43 | } 44 | 45 | pub fn rkey(&self, device: &Device) -> u32 { 46 | self.pool.buffer.rkey(device.index()) 47 | } 48 | } 49 | 50 | impl BufferPool { 51 | pub fn create(block_size: usize, block_count: usize, devices: &Devices) -> Result> { 52 | let buffer_size = block_size * block_count; 53 | let buffer = RegisteredBuffer::create(devices, buffer_size)?; 54 | let free_list = 
Mutex::new((0..block_count).collect()); 55 | Ok(Arc::new(Self { 56 | buffer, 57 | block_size, 58 | free_list, 59 | })) 60 | } 61 | 62 | pub fn allocate(self: &Arc) -> Result { 63 | let mut free_list = self.free_list.lock().unwrap(); 64 | match free_list.pop() { 65 | Some(idx) => Ok(Buffer { 66 | pool: self.clone(), 67 | idx, 68 | }), 69 | None => Err(Error::AllocMemoryFailed), 70 | } 71 | } 72 | 73 | fn deallocate(&self, idx: usize) { 74 | let mut free_list = self.free_list.lock().unwrap(); 75 | free_list.push(idx); 76 | } 77 | } 78 | 79 | #[cfg(test)] 80 | mod tests { 81 | use super::*; 82 | 83 | #[test] 84 | fn test_buffer() { 85 | const LEN: usize = 1 << 20; 86 | let devices = Devices::availables().unwrap(); 87 | let buffer_pool = BufferPool::create(LEN, 32, &devices).unwrap(); 88 | let mut buffer = buffer_pool.allocate().unwrap(); 89 | assert_eq!(buffer.len(), LEN); 90 | buffer.fill(1); 91 | 92 | let mut another = buffer_pool.allocate().unwrap(); 93 | assert_ne!(buffer.as_ptr(), another.as_ptr()); 94 | another.fill(2); 95 | drop(another); 96 | drop(buffer); 97 | 98 | /* the free list is LIFO, so the block freed last (filled with 1) comes back first */ let buffer = buffer_pool.allocate().unwrap(); 99 | assert_eq!(buffer.len(), LEN); 100 | assert!(buffer.iter().all(|&x| x == 1)); 101 | 102 | let another = buffer_pool.allocate().unwrap(); 103 | assert_eq!(another.len(), LEN); 104 | assert!(another.iter().all(|&x| x == 2)); 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /r2dma/src/buf/mod.rs: -------------------------------------------------------------------------------- 1 | mod aligned_buffer; 2 | pub use aligned_buffer::AlignedBuffer; 3 | 4 | mod rdma_buffer; 5 | pub use rdma_buffer::RegisteredBuffer; 6 | 7 | mod buffer_pool; 8 | pub use buffer_pool::{Buffer, BufferPool}; 9 | -------------------------------------------------------------------------------- /r2dma/src/buf/rdma_buffer.rs: -------------------------------------------------------------------------------- 1 | use crate::*; 2 | 3 | struct
RawMemoryRegion(*mut verbs::ibv_mr); 4 | impl std::ops::Deref for RawMemoryRegion { 5 | type Target = verbs::ibv_mr; 6 | 7 | fn deref(&self) -> &Self::Target { 8 | unsafe { &*self.0 } 9 | } 10 | } 11 | impl Drop for RawMemoryRegion { 12 | fn drop(&mut self) { 13 | let _ = unsafe { verbs::ibv_dereg_mr(self.0) }; 14 | } 15 | } 16 | unsafe impl Send for RawMemoryRegion {} 17 | unsafe impl Sync for RawMemoryRegion {} 18 | 19 | pub struct RegisteredBuffer { 20 | memory_regions: Vec, 21 | aligned_buffer: AlignedBuffer, 22 | _devices: Devices, 23 | } 24 | 25 | impl RegisteredBuffer { 26 | pub fn create(devices: &Devices, size: usize) -> Result { 27 | let mut buf = AlignedBuffer::new(size)?; 28 | let mut memory_regions = Vec::with_capacity(devices.len()); 29 | for device in devices { 30 | let mr = unsafe { 31 | verbs::ibv_reg_mr( 32 | device.pd_ptr(), 33 | buf.as_mut_ptr() as _, 34 | buf.len(), 35 | verbs::ACCESS_FLAGS as _, 36 | ) 37 | }; 38 | /* ibv_reg_mr returns NULL on failure; a null MR must never be stored, because Deref and Drop (ibv_dereg_mr) would dereference it. Already-registered MRs are deregistered when `memory_regions` drops, before `buf` is freed */ if mr.is_null() { tracing::error!("ibv_reg_mr failed: {}", std::io::Error::last_os_error()); return Err(Error::AllocMemoryFailed); } memory_regions.push(RawMemoryRegion(mr)); 39 | } 40 | Ok(Self { 41 | memory_regions, 42 | aligned_buffer: buf, 43 | _devices: devices.clone(), 44 | }) 45 | } 46 | 47 | pub fn lkey(&self, index: usize) -> u32 { 48 | self.memory_regions[index].lkey 49 | } 50 | 51 | pub fn rkey(&self, index: usize) -> u32 { 52 | self.memory_regions[index].rkey 53 | } 54 | } 55 | 56 | impl std::ops::Deref for RegisteredBuffer { 57 | type Target = [u8]; 58 | 59 | #[inline(always)] 60 | fn deref(&self) -> &Self::Target { 61 | &self.aligned_buffer 62 | } 63 | } 64 | 65 | impl std::fmt::Debug for RegisteredBuffer { 66 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 67 | f.debug_struct("RegisteredBuffer") 68 | .field("addr", &self.aligned_buffer.as_ptr()) 69 | .field("len", &self.aligned_buffer.len()) 70 | .field("num_mrs", &self.memory_regions.len()) 71 | .finish() 72 | } 73 | } 74 | 75 | #[cfg(test)] 76 | mod tests { 77 | use super::*; 78 | 79 | #[test] 80 | fn test_memory_region() { 81 | let size = 4096usize; 82 | let devices =
Devices::availables().unwrap(); 83 | let registered_buffer = RegisteredBuffer::create(&devices, size).unwrap(); 84 | assert_eq!(registered_buffer.len(), size); 85 | println!("{:#?}", registered_buffer); 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /r2dma/src/core/comp_queues.rs: -------------------------------------------------------------------------------- 1 | use super::Devices; 2 | use crate::{verbs, Error, Result}; 3 | use std::sync::Arc; 4 | 5 | struct RawCompQueue(*mut verbs::ibv_cq); 6 | impl std::ops::Deref for RawCompQueue { 7 | type Target = verbs::ibv_cq; 8 | 9 | fn deref(&self) -> &Self::Target { 10 | unsafe { &*self.0 } 11 | } 12 | } 13 | impl Drop for RawCompQueue { 14 | fn drop(&mut self) { 15 | let _ = unsafe { verbs::ibv_destroy_cq(self.0) }; 16 | } 17 | } 18 | unsafe impl Send for RawCompQueue {} 19 | unsafe impl Sync for RawCompQueue {} 20 | 21 | pub struct CompQueues { 22 | comp_queues: Vec, 23 | pub cqe: usize, 24 | _devices: Devices, 25 | } 26 | 27 | impl CompQueues { 28 | pub fn create(devices: &Devices, max_cqe: u32) -> Result> { 29 | let mut comp_queues = Vec::with_capacity(devices.len()); 30 | for device in devices { 31 | let ptr = unsafe { 32 | verbs::ibv_create_cq( 33 | device.context_ptr(), 34 | max_cqe as _, 35 | std::ptr::null_mut(), 36 | std::ptr::null_mut(), 37 | 0, 38 | ) 39 | }; 40 | if ptr.is_null() { 41 | return Err(Error::IBCreateCompQueueFail(std::io::Error::last_os_error())); 42 | } 43 | comp_queues.push(RawCompQueue(ptr)); 44 | } 45 | let cqe = comp_queues.first().unwrap().cqe as usize; 46 | 47 | let this = Self { 48 | comp_queues, 49 | cqe, 50 | _devices: devices.clone(), 51 | }; 52 | Ok(Arc::new(this)) 53 | } 54 | 55 | pub(crate) fn comp_queue_ptr(&self, device_index: usize) -> *mut verbs::ibv_cq { 56 | self.comp_queues[device_index].0 57 | } 58 | 59 | pub fn num_entries(&self) -> usize { 60 | self.comp_queues.len() * self.cqe 61 | } 62 | 63 | pub fn 
poll_cq<'a>(&self, wcs: &'a mut [verbs::ibv_wc]) -> Result<&'a mut [verbs::ibv_wc]> { 64 | assert!(wcs.len() >= self.comp_queues.len()); 65 | let mut offset = 0usize; 66 | let num_entries = (wcs.len() / self.comp_queues.len()) as i32; 67 | for comp_queue in &self.comp_queues { 68 | let num = unsafe { 69 | verbs::ibv_poll_cq(comp_queue.0, num_entries, wcs.as_mut_ptr().add(offset) as _) 70 | }; 71 | if num >= 0 { 72 | offset += num as usize; 73 | } else { 74 | tracing::error!( 75 | "poll comp queue failed: {}", 76 | std::io::Error::last_os_error() 77 | ); 78 | } 79 | } 80 | Ok(&mut wcs[..offset]) 81 | } 82 | } 83 | 84 | impl std::fmt::Debug for CompQueues { 85 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 86 | f.debug_struct("CompQueue") 87 | .field("cqe", &self.cqe) 88 | .field("num_cqs", &self.comp_queues.len()) 89 | .finish() 90 | } 91 | } 92 | 93 | #[cfg(test)] 94 | mod tests { 95 | use super::*; 96 | 97 | #[test] 98 | fn test_comp_queue() { 99 | let max_cqe = 1024; 100 | let devices = Devices::availables().unwrap(); 101 | let comp_queues = CompQueues::create(&devices, max_cqe).unwrap(); 102 | println!("{:#?}", comp_queues); 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /r2dma/src/core/config.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashSet; 2 | 3 | #[derive(Debug, PartialEq, Eq, Hash)] 4 | pub enum GidType { 5 | IB, 6 | RoCEv1, 7 | RoCEv2, 8 | Other(String), 9 | } 10 | 11 | #[derive(Debug, Default)] 12 | pub struct Config { 13 | pub device: DeviceConfig, 14 | } 15 | 16 | #[derive(Debug, Default)] 17 | pub struct DeviceConfig { 18 | pub device_filter: HashSet, 19 | pub gid_type_filter: HashSet, 20 | pub skip_inactive_port: bool, 21 | pub roce_v2_skip_link_local_addr: bool, 22 | } 23 | -------------------------------------------------------------------------------- /r2dma/src/core/devices.rs: 
-------------------------------------------------------------------------------- 1 | use super::{DeviceConfig, GidType}; 2 | use crate::{verbs, Error, Result}; 3 | use std::{ 4 | ffi::{c_int, CStr, OsStr}, 5 | ops::Deref, 6 | os::unix::ffi::OsStrExt, 7 | path::{Path, PathBuf}, 8 | sync::Arc, 9 | }; 10 | 11 | struct RawDeviceList { 12 | ptr: *mut *mut verbs::ibv_device, 13 | num_devices: usize, 14 | } 15 | 16 | impl RawDeviceList { 17 | fn available() -> Result { 18 | let mut num_devices: c_int = 0; 19 | let ptr = unsafe { verbs::ibv_get_device_list(&mut num_devices) }; 20 | if ptr.is_null() { 21 | return Err(Error::IBGetDeviceListFail(std::io::Error::last_os_error())); 22 | } 23 | if num_devices == 0 { 24 | return Err(Error::IBDeviceNotFound); 25 | } 26 | Ok(Self { 27 | ptr, 28 | num_devices: num_devices as usize, 29 | }) 30 | } 31 | } 32 | 33 | impl Drop for RawDeviceList { 34 | fn drop(&mut self) { 35 | unsafe { verbs::ibv_free_device_list(self.ptr) }; 36 | } 37 | } 38 | 39 | impl Deref for RawDeviceList { 40 | type Target = [*mut verbs::ibv_device]; 41 | 42 | fn deref(&self) -> &Self::Target { 43 | unsafe { std::slice::from_raw_parts(self.ptr, self.num_devices) } 44 | } 45 | } 46 | 47 | unsafe impl Send for RawDeviceList {} 48 | unsafe impl Sync for RawDeviceList {} 49 | 50 | struct RawContext(*mut verbs::ibv_context); 51 | impl Drop for RawContext { 52 | fn drop(&mut self) { 53 | let _ = unsafe { verbs::ibv_close_device(self.0) }; 54 | } 55 | } 56 | impl RawContext { 57 | fn query_device(&self) -> Result { 58 | let mut device_attr = verbs::ibv_device_attr::default(); 59 | let ret = unsafe { verbs::ibv_query_device(self.0, &mut device_attr) }; 60 | if ret != 0 { 61 | Err(Error::IBQueryDeviceFail(std::io::Error::last_os_error())) 62 | } else { 63 | Ok(device_attr) 64 | } 65 | } 66 | 67 | fn query_port(&self, port_num: u8) -> Result { 68 | let mut port_attr = std::mem::MaybeUninit::::uninit(); 69 | let ret = unsafe { verbs::ibv_query_port(self.0, port_num, 
port_attr.as_mut_ptr() as _) }; 70 | if ret == 0 { 71 | Ok(unsafe { port_attr.assume_init() }) 72 | } else { 73 | Err(Error::IBQueryPortFail(std::io::Error::last_os_error())) 74 | } 75 | } 76 | 77 | fn query_gid(&self, port_num: u8, gid_index: u16) -> Result { 78 | let mut gid = verbs::ibv_gid::default(); 79 | let ret = unsafe { verbs::ibv_query_gid(self.0, port_num as _, gid_index as _, &mut gid) }; 80 | if ret == 0 && !gid.is_null() { 81 | Ok(gid) 82 | } else { 83 | Err(Error::IBQueryGidFail(std::io::Error::last_os_error())) 84 | } 85 | } 86 | 87 | fn query_gid_type( 88 | &self, 89 | port_num: u8, 90 | gid_index: u16, 91 | ibdev_path: &Path, 92 | port_attr: &verbs::ibv_port_attr, 93 | ) -> Result { 94 | let path = ibdev_path.join(format!("ports/{}/gid_attrs/types/{}", port_num, gid_index)); 95 | match std::fs::read_to_string(path) { 96 | Ok(content) => { 97 | if content == "IB/RoCE v1\n" { 98 | if port_attr.link_layer == verbs::IBV_LINK_LAYER::INFINIBAND as u8 { 99 | Ok(GidType::IB) 100 | } else { 101 | Ok(GidType::RoCEv1) 102 | } 103 | } else if content == "RoCE v2\n" { 104 | Ok(GidType::RoCEv2) 105 | } else { 106 | Ok(GidType::Other(content.trim().to_string())) 107 | } 108 | } 109 | Err(err) => Err(Error::IBQueryGidTypeFail(err)), 110 | } 111 | } 112 | } 113 | unsafe impl Send for RawContext {} 114 | unsafe impl Sync for RawContext {} 115 | 116 | pub struct RawProtectionDomain(*mut verbs::ibv_pd); 117 | impl Drop for RawProtectionDomain { 118 | fn drop(&mut self) { 119 | let _ = unsafe { verbs::ibv_dealloc_pd(self.0) }; 120 | } 121 | } 122 | unsafe impl Send for RawProtectionDomain {} 123 | unsafe impl Sync for RawProtectionDomain {} 124 | 125 | #[derive(Debug, Default)] 126 | #[allow(unused)] 127 | pub struct DeviceInfo { 128 | pub index: usize, 129 | pub name: String, 130 | pub guid: u64, 131 | pub ibdev_path: PathBuf, 132 | pub device_attr: verbs::ibv_device_attr, 133 | pub ports: Vec, 134 | } 135 | 136 | #[allow(unused)] 137 | pub struct Device { 138 | 
protection_domain: RawProtectionDomain, 139 | context: RawContext, 140 | device: *mut verbs::ibv_device, 141 | list: Arc, 142 | info: DeviceInfo, 143 | } 144 | 145 | unsafe impl Send for Device {} 146 | unsafe impl Sync for Device {} 147 | 148 | #[derive(Debug)] 149 | pub struct Port { 150 | pub port_num: u8, 151 | /// The attributes of the port. 152 | pub port_attr: verbs::ibv_port_attr, 153 | /// The GID (Global Identifier) list of the port. 154 | pub gids: Vec<(u16, verbs::ibv_gid, GidType)>, 155 | } 156 | 157 | #[allow(unused)] 158 | impl Device { 159 | fn open( 160 | list: Arc, 161 | device: *mut verbs::ibv_device, 162 | index: usize, 163 | config: &DeviceConfig, 164 | ) -> Result { 165 | let name = unsafe { CStr::from_ptr((*device).name.as_ptr()) } 166 | .to_string_lossy() 167 | .to_string(); 168 | let guid = u64::from_be(unsafe { verbs::ibv_get_device_guid(device) }); 169 | let str = unsafe { CStr::from_ptr((*device).ibdev_path.as_ptr()) }; 170 | let ibdev_path = PathBuf::from(OsStr::from_bytes(str.to_bytes())); 171 | 172 | let context = RawContext(unsafe { 173 | let context = verbs::ibv_open_device(device); 174 | if context.is_null() { 175 | return Err(Error::IBOpenDeviceFail(std::io::Error::last_os_error())); 176 | } 177 | context 178 | }); 179 | 180 | let protection_domain = RawProtectionDomain(unsafe { 181 | let protection_domain = verbs::ibv_alloc_pd(context.0); 182 | if protection_domain.is_null() { 183 | return Err(Error::IBAllocPDFail(std::io::Error::last_os_error())); 184 | } 185 | protection_domain 186 | }); 187 | 188 | let mut device = Self { 189 | protection_domain, 190 | context, 191 | device, 192 | list, 193 | info: DeviceInfo { 194 | index, 195 | name, 196 | guid, 197 | ibdev_path, 198 | ..Default::default() 199 | }, 200 | }; 201 | device.update_attr(config)?; 202 | 203 | Ok(device) 204 | } 205 | 206 | fn update_attr(&mut self, config: &DeviceConfig) -> Result<()> { 207 | // 1. query device attr. 
208 | let device_attr = self.context.query_device()?; 209 | 210 | let mut ports = vec![]; 211 | for port_num in 1..=device_attr.phys_port_cnt { 212 | let port_attr = self.context.query_port(port_num)?; 213 | if port_attr.state != verbs::ibv_port_state::IBV_PORT_ACTIVE 214 | && config.skip_inactive_port 215 | { 216 | continue; 217 | } 218 | 219 | let mut gids = vec![]; 220 | for gid_index in 0..port_attr.gid_tbl_len as u16 { 221 | if let Ok(gid) = self.context.query_gid(port_num, gid_index) { 222 | let gid_type = self.context.query_gid_type( 223 | port_num, 224 | gid_index, 225 | &self.info.ibdev_path, 226 | &port_attr, 227 | )?; 228 | if !config.gid_type_filter.is_empty() 229 | && !config.gid_type_filter.contains(&gid_type) 230 | { 231 | continue; 232 | } 233 | 234 | if config.roce_v2_skip_link_local_addr && gid_type == GidType::RoCEv2 { 235 | let ip = gid.as_ipv6(); 236 | if ip.is_unicast_link_local() { 237 | continue; 238 | } 239 | } 240 | 241 | gids.push((gid_index, gid, gid_type)) 242 | } 243 | } 244 | 245 | ports.push(Port { 246 | port_num, 247 | port_attr, 248 | gids, 249 | }); 250 | } 251 | 252 | self.info.device_attr = device_attr; 253 | self.info.ports = ports; 254 | 255 | Ok(()) 256 | } 257 | 258 | pub(crate) fn device_ptr(&self) -> *mut verbs::ibv_device { 259 | self.device 260 | } 261 | 262 | pub(crate) fn context_ptr(&self) -> *mut verbs::ibv_context { 263 | self.context.0 264 | } 265 | 266 | pub(crate) fn pd_ptr(&self) -> *mut verbs::ibv_pd { 267 | self.protection_domain.0 268 | } 269 | 270 | pub fn index(&self) -> usize { 271 | self.info.index 272 | } 273 | 274 | pub fn info(&self) -> &DeviceInfo { 275 | &self.info 276 | } 277 | } 278 | 279 | impl std::fmt::Debug for Device { 280 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 281 | std::fmt::Debug::fmt(&self.info, f) 282 | } 283 | } 284 | 285 | #[derive(Clone)] 286 | pub struct Devices(Arc>); 287 | 288 | impl Devices { 289 | pub fn availables() -> Result { 290 | 
Self::open(&Default::default()) 291 | } 292 | 293 | pub fn open(config: &DeviceConfig) -> Result { 294 | let list = Arc::new(RawDeviceList::available()?); 295 | let mut devices = Vec::with_capacity(list.len()); 296 | for &device in list.iter() { 297 | /* apply the name filter before opening, so filtered-out devices are never opened (no context/PD allocation) and cannot fail the whole call; same name lookup as Device::open */ let name = unsafe { CStr::from_ptr((*device).name.as_ptr()) }.to_string_lossy(); 298 | if !config.device_filter.is_empty() && !config.device_filter.contains(&*name) 299 | { 300 | tracing::debug!( 301 | "skip device {} by filter: {:?}", 302 | name, 303 | config.device_filter 304 | ); 305 | continue; 306 | } 307 | let index = devices.len(); 308 | let device = Device::open(list.clone(), device, index, config)?; 309 | devices.push(device); 310 | } 311 | if devices.is_empty() { 312 | Err(Error::IBDeviceNotFound) 313 | } else { 314 | Ok(Devices(Arc::new(devices))) 315 | } 316 | } 317 | } 318 | 319 | impl Deref for Devices { 320 | type Target = [Device]; 321 | 322 | fn deref(&self) -> &Self::Target { 323 | &self.0 324 | } 325 | } 326 | 327 | impl<'a> IntoIterator for &'a Devices { 328 | type Item = &'a Device; 329 | type IntoIter = std::slice::Iter<'a, Device>; 330 | 331 | fn into_iter(self) -> Self::IntoIter { 332 | self.iter() 333 | } 334 | } 335 | 336 | #[cfg(test)] 337 | mod tests { 338 | use super::*; 339 | 340 | #[test] 341 | fn list_devices() { 342 | let devices = Devices::availables().unwrap(); 343 | assert!(!devices.is_empty()); 344 | for device in &devices { 345 | println!("{:#?}", device); 346 | } 347 | } 348 | } 349 | -------------------------------------------------------------------------------- /r2dma/src/core/event_loop.rs: -------------------------------------------------------------------------------- 1 | use super::{CompQueues, Devices}; 2 | use crate::{verbs, Result}; 3 | use std::sync::{ 4 | atomic::{AtomicBool, Ordering}, 5 | Arc, 6 | }; 7 | 8 | pub struct EventLoopState { 9 | stopping: AtomicBool, 10 | comp_queues: Arc, 11 | } 12 | 13 | pub struct EventLoop { 14 | state: Arc, 15 | handle: Option>, 16 | } 17 | 18 | impl EventLoop { 19 | pub fn create(devices:
&Devices) -> Result { 20 | let max_cqe = 32; 21 | let comp_queues = CompQueues::create(devices, max_cqe)?; 22 | let state = Arc::new(EventLoopState { 23 | stopping: AtomicBool::new(false), 24 | comp_queues, 25 | }); 26 | 27 | let handle = std::thread::spawn({ 28 | let state = state.clone(); 29 | move || EventLoop::run(state) 30 | }); 31 | 32 | Ok(EventLoop { 33 | state, 34 | handle: Some(handle), 35 | }) 36 | } 37 | 38 | pub fn stop_and_join(&mut self) { 39 | self.state.stopping.store(true, Ordering::Release); 40 | if let Some(handle) = self.handle.take() { 41 | handle.join().unwrap(); 42 | } 43 | } 44 | 45 | pub fn run(state: Arc) { 46 | let comp_queues = state.comp_queues.clone(); 47 | let num_entiries = comp_queues.num_entries(); 48 | let mut wcs = vec![verbs::ibv_wc::default(); num_entiries]; 49 | 50 | while !state.stopping.load(Ordering::Acquire) { 51 | // poll for events. 52 | let wcs = comp_queues.poll_cq(&mut wcs).unwrap(); 53 | if wcs.is_empty() { 54 | std::thread::sleep(std::time::Duration::from_millis(1)); 55 | continue; 56 | } 57 | 58 | // handle events. 
59 | for wc in wcs { 60 | tracing::info!( 61 | "wc is id {}, result {}, status {:?}", 62 | wc.wr_id, 63 | wc.byte_len, 64 | wc.status 65 | ); 66 | } 67 | } 68 | } 69 | } 70 | 71 | impl Drop for EventLoop { 72 | fn drop(&mut self) { 73 | self.stop_and_join(); 74 | } 75 | } 76 | 77 | #[cfg(test)] 78 | mod tests { 79 | use super::*; 80 | 81 | #[test] 82 | fn test_event_loop() { 83 | let devices = Devices::availables().unwrap(); 84 | let event_loop = EventLoop::create(&devices).unwrap(); 85 | std::thread::sleep(std::time::Duration::from_millis(200)); 86 | drop(event_loop); 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /r2dma/src/core/mod.rs: -------------------------------------------------------------------------------- 1 | mod config; 2 | pub use config::{Config, DeviceConfig, GidType}; 3 | 4 | mod devices; 5 | pub use devices::{Device, Devices}; 6 | 7 | mod comp_queues; 8 | pub use comp_queues::CompQueues; 9 | 10 | mod queue_pair; 11 | pub use queue_pair::{Endpoint, QueuePair}; 12 | 13 | mod event_loop; 14 | pub use event_loop::EventLoop; 15 | -------------------------------------------------------------------------------- /r2dma/src/core/queue_pair.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use crate::{verbs, Error, Result}; 3 | use derse::{Deserialize, Serialize}; 4 | use std::{ffi::c_int, ops::Deref, sync::Arc}; 5 | 6 | #[derive(Debug, Deserialize, Serialize)] 7 | pub struct Endpoint { 8 | pub qp_num: u32, 9 | pub lid: u16, 10 | pub gid: verbs::ibv_gid, 11 | } 12 | 13 | struct RawQueuePair(*mut verbs::ibv_qp); 14 | impl Drop for RawQueuePair { 15 | fn drop(&mut self) { 16 | let _ = unsafe { verbs::ibv_destroy_qp(self.0) }; 17 | } 18 | } 19 | unsafe impl Send for RawQueuePair {} 20 | unsafe impl Sync for RawQueuePair {} 21 | 22 | pub struct QueuePair { 23 | queue_pair: RawQueuePair, 24 | _comp_queues: Arc, 25 | _device_index: usize, 26 | _devices: 
Devices, 27 | } 28 | 29 | impl QueuePair { 30 | pub fn create( 31 | devices: &Devices, 32 | device_index: usize, 33 | comp_queues: &Arc, 34 | cap: verbs::ibv_qp_cap, 35 | ) -> Result { 36 | let mut attr = verbs::ibv_qp_init_attr { 37 | qp_context: std::ptr::null_mut(), 38 | send_cq: comp_queues.comp_queue_ptr(device_index), 39 | recv_cq: comp_queues.comp_queue_ptr(device_index), 40 | srq: std::ptr::null_mut(), 41 | cap, 42 | qp_type: verbs::ibv_qp_type::IBV_QPT_RC, 43 | sq_sig_all: 0, 44 | }; 45 | let ptr = unsafe { verbs::ibv_create_qp(devices[device_index].pd_ptr(), &mut attr) }; 46 | if ptr.is_null() { 47 | return Err(Error::IBCreateQueuePairFail(std::io::Error::last_os_error())); 48 | } 49 | Ok(Self { 50 | queue_pair: RawQueuePair(ptr), 51 | _comp_queues: comp_queues.clone(), 52 | _device_index: device_index, 53 | _devices: devices.clone(), 54 | }) 55 | } 56 | 57 | pub fn init(&mut self, port_num: u8, pkey_index: u16) -> Result<()> { 58 | let mut attr = verbs::ibv_qp_attr { 59 | qp_state: verbs::ibv_qp_state::IBV_QPS_INIT, 60 | pkey_index, 61 | port_num, 62 | qp_access_flags: verbs::ACCESS_FLAGS, 63 | ..Default::default() 64 | }; 65 | 66 | const MASK: verbs::ibv_qp_attr_mask = verbs::ibv_qp_attr_mask( 67 | verbs::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX.0 68 | | verbs::ibv_qp_attr_mask::IBV_QP_STATE.0 69 | | verbs::ibv_qp_attr_mask::IBV_QP_PORT.0 70 | | verbs::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS.0, 71 | ); 72 | 73 | self.modify_qp(&mut attr, MASK) 74 | } 75 | 76 | pub fn ready_to_recv(&self, remote: &Endpoint) -> Result<()> { 77 | let mut attr = verbs::ibv_qp_attr { 78 | qp_state: verbs::ibv_qp_state::IBV_QPS_RTR, 79 | path_mtu: verbs::ibv_mtu::IBV_MTU_512, 80 | dest_qp_num: remote.qp_num, 81 | rq_psn: 0, 82 | max_dest_rd_atomic: 1, 83 | min_rnr_timer: 0x12, 84 | ah_attr: verbs::ibv_ah_attr { 85 | grh: verbs::ibv_global_route { 86 | dgid: remote.gid, 87 | flow_label: 0, 88 | sgid_index: 1, 89 | hop_limit: 0xff, 90 | traffic_class: 0, 91 | }, 92 | dlid: remote.lid,
93 | sl: 0, 94 | src_path_bits: 0, 95 | static_rate: 0, 96 | is_global: 1, 97 | port_num: 1, 98 | }, 99 | ..Default::default() 100 | }; 101 | 102 | const MASK: verbs::ibv_qp_attr_mask = verbs::ibv_qp_attr_mask( 103 | verbs::ibv_qp_attr_mask::IBV_QP_STATE.0 104 | | verbs::ibv_qp_attr_mask::IBV_QP_AV.0 105 | | verbs::ibv_qp_attr_mask::IBV_QP_PATH_MTU.0 106 | | verbs::ibv_qp_attr_mask::IBV_QP_DEST_QPN.0 107 | | verbs::ibv_qp_attr_mask::IBV_QP_RQ_PSN.0 108 | | verbs::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC.0 109 | | verbs::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER.0, 110 | ); 111 | 112 | self.modify_qp(&mut attr, MASK) 113 | } 114 | 115 | pub fn ready_to_send(&self) -> Result<()> { 116 | let mut attr = verbs::ibv_qp_attr { 117 | qp_state: verbs::ibv_qp_state::IBV_QPS_RTS, 118 | timeout: 0x12, 119 | retry_cnt: 6, 120 | rnr_retry: 6, 121 | sq_psn: 0, 122 | max_rd_atomic: 1, 123 | ..Default::default() 124 | }; 125 | 126 | const MASK: verbs::ibv_qp_attr_mask = verbs::ibv_qp_attr_mask( 127 | verbs::ibv_qp_attr_mask::IBV_QP_STATE.0 128 | | verbs::ibv_qp_attr_mask::IBV_QP_TIMEOUT.0 129 | | verbs::ibv_qp_attr_mask::IBV_QP_RETRY_CNT.0 130 | | verbs::ibv_qp_attr_mask::IBV_QP_RNR_RETRY.0 131 | | verbs::ibv_qp_attr_mask::IBV_QP_SQ_PSN.0 132 | | verbs::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC.0, 133 | ); 134 | 135 | self.modify_qp(&mut attr, MASK) 136 | } 137 | 138 | pub fn set_error(&self) { 139 | let mut attr = verbs::ibv_qp_attr { 140 | qp_state: verbs::ibv_qp_state::IBV_QPS_ERR, 141 | ..Default::default() 142 | }; 143 | 144 | const MASK: verbs::ibv_qp_attr_mask = verbs::ibv_qp_attr_mask::IBV_QP_STATE; 145 | 146 | // assuming this operation succeeds.
147 | self.modify_qp(&mut attr, MASK).unwrap() 148 | } 149 | 150 | pub fn post_send(&self, wr: &mut verbs::ibv_send_wr) -> c_int { 151 | let mut bad_wr = std::ptr::null_mut(); 152 | unsafe { verbs::ibv_post_send(self.queue_pair.0, wr, &mut bad_wr) } 153 | } 154 | 155 | pub fn post_recv(&self, wr: &mut verbs::ibv_recv_wr) -> c_int { 156 | let mut bad_wr = std::ptr::null_mut(); 157 | unsafe { verbs::ibv_post_recv(self.queue_pair.0, wr, &mut bad_wr) } 158 | } 159 | 160 | fn modify_qp( 161 | &self, 162 | attr: &mut verbs::ibv_qp_attr, 163 | mask: verbs::ibv_qp_attr_mask, 164 | ) -> Result<()> { 165 | let ret = unsafe { verbs::ibv_modify_qp(self.queue_pair.0, attr, mask.0 as _) }; 166 | if ret == 0_i32 { 167 | Ok(()) 168 | } else { 169 | Err(Error::IBModifyQueuePairFail(std::io::Error::last_os_error())) 170 | } 171 | } 172 | } 173 | 174 | impl Deref for QueuePair { 175 | type Target = verbs::ibv_qp; 176 | 177 | fn deref(&self) -> &Self::Target { 178 | unsafe { &*self.queue_pair.0 } 179 | } 180 | } 181 | 182 | impl std::fmt::Debug for QueuePair { 183 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 184 | f.debug_struct("QueuePair") 185 | .field("handle", &self.handle) 186 | .field("qp_num", &self.qp_num) 187 | .field("state", &self.state) 188 | .field("qp_type", &self.qp_type) 189 | .field("events_completed", &self.events_completed) 190 | .finish() 191 | } 192 | } 193 | 194 | #[cfg(test)] 195 | mod tests { 196 | use super::*; 197 | use crate::*; 198 | 199 | #[test] 200 | fn test_queue_pair_create() { 201 | let devices = Devices::availables().unwrap(); 202 | let comp_queues = CompQueues::create(&devices, 128).unwrap(); 203 | let cap = verbs::ibv_qp_cap { 204 | max_send_wr: 64, 205 | max_recv_wr: 64, 206 | max_send_sge: 1, 207 | max_recv_sge: 1, 208 | max_inline_data: 0, 209 | }; 210 | let mut queue_pair = QueuePair::create(&devices, 0, &comp_queues, cap).unwrap(); 211 | println!("{:#?}", queue_pair); 212 | 213 | queue_pair.init(1, 0).unwrap(); 214
| queue_pair.set_error(); 215 | } 216 | 217 | #[test] 218 | fn test_queue_pair_send_recv() { 219 | // 1. list all available devices. 220 | let devices = Devices::availables().unwrap(); 221 | 222 | // 2. create two queue pairs. 223 | let cap = verbs::ibv_qp_cap { 224 | max_send_wr: 64, 225 | max_recv_wr: 64, 226 | max_send_sge: 1, 227 | max_recv_sge: 1, 228 | max_inline_data: 0, 229 | }; 230 | 231 | let comp_queues_a = CompQueues::create(&devices, 128).unwrap(); 232 | let mut queue_pair_a = QueuePair::create(&devices, 0, &comp_queues_a, cap).unwrap(); 233 | let comp_queues_b = CompQueues::create(&devices, 128).unwrap(); 234 | let mut queue_pair_b = QueuePair::create(&devices, 0, &comp_queues_b, cap).unwrap(); 235 | 236 | // 3. init all queue pairs. 237 | queue_pair_a.init(1, 0).unwrap(); 238 | queue_pair_b.init(1, 0).unwrap(); 239 | 240 | // 4. post recv wr. 241 | const LEN: usize = 1 << 20; 242 | let buffer_pool = BufferPool::create(LEN, 32, &devices).unwrap(); 243 | 244 | let mut recv_buf = buffer_pool.allocate().unwrap(); 245 | recv_buf.fill(0); 246 | let mut recv_sge = verbs::ibv_sge { 247 | addr: recv_buf.as_ptr() as _, 248 | length: recv_buf.len() as _, 249 | lkey: recv_buf.lkey(&devices[0]), 250 | }; 251 | let mut recv_wr = verbs::ibv_recv_wr { 252 | wr_id: 1, 253 | sg_list: &mut recv_sge as *mut _, 254 | num_sge: 1, 255 | next: std::ptr::null_mut(), 256 | }; 257 | assert_eq!(queue_pair_b.post_recv(&mut recv_wr), 0); 258 | 259 | // 5. connect two queue pairs.
260 | let device = &devices[0]; 261 | let gid = device.info().ports[0].gids[1].1; 262 | queue_pair_a 263 | .ready_to_recv(&Endpoint { 264 | qp_num: queue_pair_b.qp_num, 265 | lid: 0, 266 | gid, 267 | }) 268 | .unwrap(); 269 | queue_pair_b 270 | .ready_to_recv(&Endpoint { 271 | qp_num: queue_pair_a.qp_num, 272 | lid: 0, 273 | gid, 274 | }) 275 | .unwrap(); 276 | 277 | queue_pair_a.ready_to_send().unwrap(); 278 | queue_pair_b.ready_to_send().unwrap(); 279 | 280 | let mut wcs_b = vec![verbs::ibv_wc::default(); 128]; 281 | assert!(comp_queues_b.poll_cq(&mut wcs_b).unwrap().is_empty()); 282 | 283 | // 6. post send wr. 284 | let mut send_buf = buffer_pool.allocate().unwrap(); 285 | send_buf.fill(1); 286 | let mut send_sge = verbs::ibv_sge { 287 | addr: send_buf.as_ptr() as _, 288 | length: send_buf.len() as _, 289 | lkey: send_buf.lkey(&devices[0]), 290 | }; 291 | let mut send_wr = verbs::ibv_send_wr { 292 | wr_id: 2, 293 | sg_list: &mut send_sge as *mut _, 294 | num_sge: 1, 295 | opcode: verbs::ibv_wr_opcode::IBV_WR_SEND, 296 | send_flags: verbs::ibv_send_flags::IBV_SEND_SIGNALED.0, 297 | ..Default::default() 298 | }; 299 | assert_eq!(queue_pair_a.post_send(&mut send_wr), 0); 300 | 301 | // 7. poll cq.
302 | std::thread::sleep(std::time::Duration::from_millis(100)); 303 | let mut wcs_a = vec![verbs::ibv_wc::default(); 128]; 304 | let comp_a = comp_queues_a.poll_cq(&mut wcs_a).unwrap(); 305 | assert_eq!(comp_a.len(), 1); 306 | assert_eq!(comp_a[0].wr_id, 2); 307 | assert_eq!(comp_a[0].qp_num, queue_pair_a.qp_num); 308 | assert_eq!(comp_a[0].status, verbs::ibv_wc_status::IBV_WC_SUCCESS); 309 | 310 | let comp_b = comp_queues_b.poll_cq(&mut wcs_b).unwrap(); 311 | assert_eq!(comp_b.len(), 1); 312 | assert_eq!(comp_b[0].wr_id, 1); 313 | assert_eq!(comp_b[0].qp_num, queue_pair_b.qp_num); 314 | assert_eq!(comp_b[0].status, verbs::ibv_wc_status::IBV_WC_SUCCESS); 315 | assert_eq!(comp_b[0].byte_len, send_buf.len() as u32); 316 | assert_eq!(recv_buf[..send_buf.len()], send_buf[..]); 317 | } 318 | } 319 | -------------------------------------------------------------------------------- /r2dma/src/error.rs: -------------------------------------------------------------------------------- 1 | #[derive(thiserror::Error, Debug)] 2 | pub enum Error { 3 | #[error("alloc memory failed")] 4 | AllocMemoryFailed, 5 | #[error("ib get device list fail: {0}")] 6 | IBGetDeviceListFail(#[source] std::io::Error), 7 | #[error("ib device is not found")] 8 | IBDeviceNotFound, 9 | #[error("ib open device fail: {0}")] 10 | IBOpenDeviceFail(#[source] std::io::Error), 11 | #[error("ib query device fail: {0}")] 12 | IBQueryDeviceFail(#[source] std::io::Error), 13 | #[error("ib query gid fail: {0}")] 14 | IBQueryGidFail(#[source] std::io::Error), 15 | #[error("ib query gid type fail: {0}")] 16 | IBQueryGidTypeFail(#[source] std::io::Error), 17 | #[error("ib query port fail: {0}")] 18 | IBQueryPortFail(#[source] std::io::Error), 19 | #[error("ib allocate protection domain fail: {0}")] 20 | IBAllocPDFail(#[source] std::io::Error), 21 | #[error("ib create completion channel fail: {0}")] 22 | IBCreateCompChannelFail(#[source] std::io::Error), 23 | #[error("ib set completion channel non-block fail: {0}")]
24 | IBSetCompChannelNonBlockFail(#[source] std::io::Error), 25 | #[error("ib get comp queue event fail: {0}")] 26 | IBGetCompQueueEventFail(#[source] std::io::Error), 27 | #[error("ib create comp queue fail: {0}")] 28 | IBCreateCompQueueFail(#[source] std::io::Error), 29 | #[error("ib req notify comp queue fail: {0}")] 30 | IBReqNotifyCompQueueFail(#[source] std::io::Error), 31 | #[error("ib poll comp queue fail: {0}")] 32 | IBPollCompQueueFail(#[source] std::io::Error), 33 | #[error("ib register memory region fail: {0}")] 34 | IBRegMemoryRegionFail(#[source] std::io::Error), 35 | #[error("ib create queue pair fail: {0}")] 36 | IBCreateQueuePairFail(#[source] std::io::Error), 37 | #[error("ib modify queue pair fail: {0}")] 38 | IBModifyQueuePairFail(#[source] std::io::Error), 39 | } 40 | 41 | pub type Result = std::result::Result; 42 | -------------------------------------------------------------------------------- /r2dma/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! r2dma 2 | //! 3 | //! A Rust RDMA library.
4 | pub mod verbs; 5 | 6 | mod core; 7 | pub use core::*; 8 | 9 | mod buf; 10 | pub use buf::*; 11 | 12 | mod error; 13 | pub use error::*; 14 | -------------------------------------------------------------------------------- /r2dma/src/verbs.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | #![allow(deref_nullptr)] 3 | #![allow(non_snake_case, non_camel_case_types, non_upper_case_globals)] 4 | #![allow(clippy::missing_safety_doc, clippy::too_many_arguments)] 5 | 6 | use std::{net::Ipv6Addr, os::raw::c_int}; 7 | 8 | #[repr(transparent)] 9 | pub struct pthread_mutex_t(pub libc::pthread_mutex_t); 10 | 11 | impl std::fmt::Debug for pthread_mutex_t { 12 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 13 | f.debug_tuple("pthread_mutex_t").finish() 14 | } 15 | } 16 | 17 | #[repr(transparent)] 18 | pub struct pthread_cond_t(pub libc::pthread_cond_t); 19 | 20 | impl std::fmt::Debug for pthread_cond_t { 21 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 22 | f.debug_tuple("pthread_cond_t").finish() 23 | } 24 | } 25 | 26 | #[derive(Debug, Clone, Copy)] 27 | #[repr(u8)] 28 | pub enum IBV_LINK_LAYER { 29 | UNSPECIFIED = 0, 30 | INFINIBAND = 1, 31 | ETHERNET = 2, 32 | } 33 | 34 | include!(concat!(env!("OUT_DIR"), "/bindings.rs")); 35 | 36 | #[inline(always)] 37 | pub unsafe fn ibv_req_notify_cq(cq: *mut ibv_cq, solicited_only: c_int) -> c_int { 38 | (*(*cq).context).ops.req_notify_cq.unwrap_unchecked()(cq, solicited_only) 39 | } 40 | 41 | #[inline(always)] 42 | pub unsafe fn ibv_poll_cq(cq: *mut ibv_cq, num_entries: c_int, wc: *mut ibv_wc) -> c_int { 43 | (*(*cq).context).ops.poll_cq.unwrap_unchecked()(cq, num_entries, wc) 44 | } 45 | 46 | #[inline(always)] 47 | pub unsafe fn ibv_post_send( 48 | qp: *mut ibv_qp, 49 | wr: *mut ibv_send_wr, 50 | bad_wr: *mut *mut ibv_send_wr, 51 | ) -> c_int { 52 | (*(*qp).context).ops.post_send.unwrap_unchecked()(qp, wr, bad_wr) 53 | } 
54 | 55 | #[inline(always)] 56 | pub unsafe fn ibv_post_recv( 57 | qp: *mut ibv_qp, 58 | wr: *mut ibv_recv_wr, 59 | bad_wr: *mut *mut ibv_recv_wr, 60 | ) -> c_int { 61 | (*(*qp).context).ops.post_recv.unwrap_unchecked()(qp, wr, bad_wr) 62 | } 63 | 64 | impl ibv_gid { 65 | pub fn as_raw(&self) -> &[u8; 16] { 66 | unsafe { &self.raw } 67 | } 68 | 69 | pub fn as_bits(&self) -> u128 { 70 | u128::from_be_bytes(unsafe { self.raw }) 71 | } 72 | 73 | pub fn as_ipv6(&self) -> std::net::Ipv6Addr { 74 | Ipv6Addr::from_bits(self.as_bits()) 75 | } 76 | 77 | pub fn subnet_prefix(&self) -> u64 { 78 | u64::from_be(unsafe { self.global.subnet_prefix }) 79 | } 80 | 81 | pub fn interface_id(&self) -> u64 { 82 | u64::from_be(unsafe { self.global.interface_id }) 83 | } 84 | 85 | pub fn is_null(&self) -> bool { 86 | self.interface_id() == 0 87 | } 88 | } 89 | 90 | impl std::fmt::Debug for ibv_gid { 91 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 92 | let gid = self 93 | .as_raw() 94 | .chunks_exact(2) 95 | .map(|b| format!("{:02x}{:02x}", b[0], b[1])) 96 | .reduce(|a, b| format!("{}:{}", a, b)) 97 | .unwrap(); 98 | f.write_str(&gid) 99 | } 100 | } 101 | 102 | impl derse::Serialize for ibv_gid { 103 | fn serialize_to(&self, serializer: &mut S) -> derse::Result<()> { 104 | self.as_raw().serialize_to(serializer) 105 | } 106 | } 107 | 108 | impl<'a> derse::Deserialize<'a> for ibv_gid { 109 | fn deserialize_from>(buf: &mut S) -> derse::Result 110 | where 111 | Self: Sized, 112 | { 113 | let mut gid = ibv_gid::default(); 114 | gid.raw = <[u8; 16]>::deserialize_from(buf)?; 115 | Ok(gid) 116 | } 117 | } 118 | 119 | pub const ACCESS_FLAGS: u32 = ibv_access_flags::IBV_ACCESS_LOCAL_WRITE.0 120 | | ibv_access_flags::IBV_ACCESS_REMOTE_WRITE.0 121 | | ibv_access_flags::IBV_ACCESS_REMOTE_READ.0 122 | | ibv_access_flags::IBV_ACCESS_RELAXED_ORDERING.0; 123 | -------------------------------------------------------------------------------- /r2pc-demo/Cargo.toml: 
-------------------------------------------------------------------------------- 1 | [package] 2 | name = "r2pc-demo" 3 | version = "0.1.0" 4 | authors.workspace = true 5 | edition.workspace = true 6 | homepage.workspace = true 7 | repository.workspace = true 8 | description.workspace = true 9 | license.workspace = true 10 | 11 | [dependencies] 12 | r2pc = { path = "../r2pc" } 13 | 14 | clap = { version = "4", features = ["derive"] } 15 | derse = "0" 16 | tokio = { version = "1", features = ["full"] } 17 | tracing.workspace = true 18 | tracing-subscriber = "0" 19 | -------------------------------------------------------------------------------- /r2pc-demo/src/bin/client.rs: -------------------------------------------------------------------------------- 1 | use clap::Parser; 2 | use r2pc::{Client, ConnectionPool, Context, Transport}; 3 | use r2pc_demo::{EchoService, GreetService, Request}; 4 | use std::sync::{ 5 | atomic::{AtomicU64, Ordering}, 6 | Arc, 7 | }; 8 | 9 | #[derive(Parser, Debug, Clone)] 10 | #[command(version, about, long_about = None)] 11 | pub struct Args { 12 | /// Server address. 13 | #[arg(default_value = "127.0.0.1:8000")] 14 | pub addr: std::net::SocketAddr, 15 | 16 | /// Request value. 17 | #[arg(short, long, default_value = "alice")] 18 | pub value: String, 19 | 20 | /// Enable stress testing. 21 | #[arg(long, default_value_t = false)] 22 | pub stress: bool, 23 | 24 | /// Stress testing duration in seconds. 25 | #[arg(long, default_value = "60")] 26 | pub secs: u64, 27 | 28 | /// The number of coroutines.
29 | #[arg(long, default_value = "32")] 30 | pub coroutines: usize, 31 | } 32 | 33 | async fn stress_test(args: Args) { 34 | let counter = Arc::new(AtomicU64::new(0)); 35 | let start_time = std::time::Instant::now(); 36 | let pool = Arc::new(ConnectionPool::new(64)); 37 | let tr = Transport::new_sync(pool, args.addr); 38 | let ctx = Context::new(tr); 39 | let mut tasks = vec![]; 40 | for _ in 0..args.coroutines { 41 | let value = Request(args.value.clone()); 42 | let counter = counter.clone(); 43 | let ctx = ctx.clone(); 44 | tasks.push(tokio::spawn(async move { 45 | while std::time::Instant::now() 46 | .duration_since(start_time) 47 | .as_secs() 48 | < args.secs 49 | { 50 | let client = Client::default(); 51 | for _ in 0..4096 { 52 | let rsp = client.echo(&ctx, &value).await; 53 | assert!(rsp.is_ok()); 54 | counter.fetch_add(1, Ordering::AcqRel); 55 | } 56 | } 57 | })); 58 | } 59 | tokio::select! { 60 | _ = async { 61 | for task in tasks { 62 | task.await.unwrap(); 63 | } 64 | } => { 65 | } 66 | _ = async { 67 | loop { 68 | tokio::time::sleep(std::time::Duration::from_secs(1)).await; 69 | tracing::info!("QPS: {}/s", counter.swap(0, Ordering::SeqCst)); 70 | } 71 | } => { 72 | } 73 | } 74 | } 75 | 76 | #[tokio::main] 77 | async fn main() { 78 | tracing_subscriber::fmt() 79 | .with_max_level(tracing::Level::INFO) 80 | .init(); 81 | 82 | let args = Args::parse(); 83 | 84 | if args.stress { 85 | stress_test(args).await; 86 | } else { 87 | let pool = Arc::new(ConnectionPool::new(4)); 88 | let tr = Transport::new_sync(pool, args.addr); 89 | let ctx = Context::new(tr); 90 | let client = Client::default(); 91 | let rsp = client.echo(&ctx, &Request(args.value.clone())).await; 92 | tracing::info!("echo rsp: {:?}", rsp); 93 | 94 | let rsp = client.greet(&ctx, &Request(args.value.clone())).await; 95 | tracing::info!("greet rsp: {:?}", rsp); 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /r2pc-demo/src/bin/server.rs:
-------------------------------------------------------------------------------- 1 | use clap::Parser; 2 | use r2pc::{Context, Result, Server}; 3 | use r2pc_demo::{EchoService, GreetService, Request}; 4 | use std::sync::{ 5 | atomic::{AtomicU64, Ordering}, 6 | Arc, 7 | }; 8 | 9 | #[derive(Parser, Debug, Clone)] 10 | #[command(version, about, long_about = None)] 11 | pub struct Args { 12 | /// Listen address. 13 | #[arg(default_value = "0.0.0.0:8000")] 14 | pub addr: std::net::SocketAddr, 15 | } 16 | 17 | #[derive(Default)] 18 | struct DemoImpl { 19 | idx: AtomicU64, 20 | } 21 | 22 | impl EchoService for DemoImpl { 23 | async fn echo(&self, _c: &Context, r: &Request) -> Result { 24 | self.idx.fetch_add(1, Ordering::AcqRel); 25 | Ok(r.0.clone()) 26 | } 27 | } 28 | 29 | impl GreetService for DemoImpl { 30 | async fn greet(&self, _c: &Context, r: &Request) -> Result { 31 | let val = self.idx.fetch_add(1, Ordering::AcqRel); 32 | Ok(format!("hello {}({})!", r.0, val)) 33 | } 34 | } 35 | 36 | #[tokio::main] 37 | async fn main() { 38 | tracing_subscriber::fmt() 39 | .with_max_level(tracing::Level::INFO) 40 | .init(); 41 | 42 | let args = Args::parse(); 43 | let mut server = Server::default(); 44 | 45 | let demo = Arc::new(DemoImpl::default()); 46 | server.add_methods(EchoService::rpc_export(demo.clone())); 47 | server.add_methods(GreetService::rpc_export(demo.clone())); 48 | 49 | let server = Arc::new(server); 50 | let (addr, listen_handle) = server.clone().listen(args.addr).await.unwrap(); 51 | tracing::info!( 52 | "Serving {:?} on {}...", 53 | [ 54 | ::NAME, 55 | ::NAME 56 | ], 57 | addr.to_string() 58 | ); 59 | listen_handle.await.unwrap(); 60 | } 61 | -------------------------------------------------------------------------------- /r2pc-demo/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![feature(return_type_notation)] 2 | use derse::{Deserialize, Serialize}; 3 | use r2pc::{Context, Result}; 4 | 5 | 
#[derive(Serialize, Deserialize, Debug, Clone)] 6 | pub struct Request(pub String); 7 | 8 | #[r2pc::service] 9 | pub trait EchoService { 10 | async fn echo(&self, c: &Context, r: &Request) -> Result; 11 | } 12 | 13 | #[r2pc::service] 14 | pub trait GreetService { 15 | async fn greet(&self, c: &Context, r: &Request) -> Result; 16 | } 17 | -------------------------------------------------------------------------------- /r2pc-macro/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "r2pc-macro" 3 | version = "0.1.0" 4 | authors.workspace = true 5 | edition.workspace = true 6 | homepage.workspace = true 7 | repository.workspace = true 8 | description = "A Rust RPC framework." 9 | license.workspace = true 10 | 11 | [lib] 12 | proc-macro = true 13 | 14 | [dependencies] 15 | proc-macro-crate = "3" 16 | proc-macro2 = "1" 17 | quote = "1" 18 | syn = { version = "2", features = ["full"] } 19 | -------------------------------------------------------------------------------- /r2pc-macro/src/lib.rs: -------------------------------------------------------------------------------- 1 | use proc_macro::TokenStream; 2 | use quote::quote; 3 | use syn::{parse_macro_input, FnArg, ItemTrait, ReturnType, TraitItem}; 4 | 5 | #[proc_macro_attribute] 6 | pub fn service(_attr: TokenStream, input: TokenStream) -> TokenStream { 7 | let input = parse_macro_input!(input as ItemTrait); 8 | 9 | let trait_ident = &input.ident; 10 | let visibility = input.vis; 11 | let trait_name = trait_ident.to_string(); 12 | 13 | let mut send_bounds = vec![]; 14 | let mut invoke_branchs = vec![]; 15 | let mut client_methods = vec![]; 16 | 17 | let krate = get_crate_name(); 18 | 19 | let input_items = input.items; 20 | for item in &input_items { 21 | if let TraitItem::Fn(method) = item { 22 | let inputs = &method.sig.inputs; 23 | if inputs.len() != 3 24 | || !matches!(inputs[0], FnArg::Receiver(_)) 25 | || method.sig.asyncness.is_none() 26 | { 27 | 
panic!("the function should be in the form `async fn func(&self, ctx: &Context, req: &Req) -> Result`."); 28 | } 29 | 30 | let method_ident = &method.sig.ident; 31 | if *method_ident == "rpc_export" || *method_ident == "rpc_call" { 32 | panic!("Functions cannot be named `rpc_export` or `rpc_call`!"); 33 | } 34 | 35 | let method_name = format!("{trait_name}/{method_ident}"); 36 | 37 | let req_type = if let FnArg::Typed(ty) = &inputs[2] { 38 | ty.ty.clone() 39 | } else { 40 | panic!("third param is not a typed arg."); 41 | }; 42 | 43 | let rsp_type = if let ReturnType::Type(_, ty) = &method.sig.output { 44 | ty.as_ref().clone() 45 | } else { 46 | panic!("return value is not a result type."); 47 | }; 48 | 49 | client_methods.push(quote! { 50 | async fn #method_ident(&self, ctx: &#krate::Context, req: #req_type) -> #rsp_type { 51 | self.rpc_call(ctx, req, #method_name).await 52 | } 53 | }); 54 | 55 | send_bounds.push(quote! { Self::#method_ident(..): Send, }); 56 | invoke_branchs.push(quote! { 57 | let this = self.clone(); 58 | map.insert( 59 | #method_name.into(), 60 | Box::new(move |ctx, meta, bytes| { 61 | let req = meta.deserialize(bytes)?; 62 | let this = this.clone(); 63 | let ctx = ctx.clone(); 64 | tokio::spawn(async move { 65 | let result = this.#method_ident(&ctx, &req).await; 66 | if let Ok(bytes) = meta.serialize(&result) { 67 | let _ = ctx.tr.send(&bytes).await; 68 | } 69 | }); 70 | Ok(()) 71 | }), 72 | ); 73 | }); 74 | } else { 75 | panic!("only function interfaces are allowed to be defined."); 76 | } 77 | } 78 | 79 | quote! 
{ 80 | #visibility trait #trait_ident { 81 | const NAME: &'static str = #trait_name; 82 | 83 | #(#input_items)* 84 | 85 | fn rpc_export( 86 | self: ::std::sync::Arc, 87 | ) -> ::std::collections::HashMap 88 | where 89 | Self: 'static + Send + Sync, 90 | #(#send_bounds)* 91 | { 92 | let mut map = ::std::collections::HashMap::::default(); 93 | #(#invoke_branchs)* 94 | map 95 | } 96 | } 97 | 98 | impl #trait_ident for #krate::Client { 99 | #(#client_methods)* 100 | } 101 | } 102 | .into() 103 | } 104 | 105 | pub(crate) fn get_crate_name() -> proc_macro2::TokenStream { 106 | let found_crate = proc_macro_crate::crate_name("r2pc").unwrap_or_else(|err| { 107 | eprintln!("Warning: {}\n => defaulting to `crate`", err,); 108 | proc_macro_crate::FoundCrate::Itself 109 | }); 110 | 111 | match found_crate { 112 | proc_macro_crate::FoundCrate::Itself => quote! { crate }, 113 | proc_macro_crate::FoundCrate::Name(name) => { 114 | let ident = syn::Ident::new(&name, proc_macro2::Span::call_site()); 115 | quote! { ::#ident } 116 | } 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /r2pc/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "r2pc" 3 | version = "0.1.0" 4 | authors.workspace = true 5 | edition = "2024" 6 | homepage.workspace = true 7 | repository.workspace = true 8 | description = "A Rust RPC framework." 
9 | license.workspace = true 10 | 11 | [dependencies] 12 | r2pc-macro = { version = "0.1.0", path = "../r2pc-macro" } 13 | 14 | derse.workspace = true 15 | thiserror.workspace = true 16 | tracing.workspace = true 17 | tokio = { version = "1", features = ["full"] } 18 | tokio-util = "0" 19 | lockmap = "0" 20 | 21 | [dev-dependencies] 22 | clap = { version = "4", features = ["derive"] } 23 | tracing-subscriber = "0" 24 | -------------------------------------------------------------------------------- /r2pc/examples/r2pc_info.rs: -------------------------------------------------------------------------------- 1 | use clap::Parser; 2 | use r2pc::{Client, ConnectionPool, Context, InfoService, Result, Transport}; 3 | use std::sync::Arc; 4 | 5 | #[derive(Parser, Debug, Clone)] 6 | #[command(version, about, long_about = None)] 7 | pub struct Args { 8 | /// Server address. 9 | #[arg(default_value = "127.0.0.1:8000")] 10 | pub addr: std::net::SocketAddr, 11 | } 12 | 13 | #[tokio::main] 14 | async fn main() -> Result<()> { 15 | tracing_subscriber::fmt() 16 | .with_max_level(tracing::Level::INFO) 17 | .init(); 18 | 19 | let args = Args::parse(); 20 | 21 | let pool = Arc::new(ConnectionPool::new(4)); 22 | let tr = Transport::new_sync(pool, args.addr); 23 | let ctx = Context::new(tr); 24 | let client = Client::default(); 25 | let rsp = client.list_methods(&ctx, &()).await?; 26 | if !rsp.is_empty() { 27 | tracing::info!( 28 | "The address {} provides the following RPC methods: {:#?}", 29 | args.addr, 30 | rsp 31 | ); 32 | } 33 | Ok(()) 34 | } 35 | -------------------------------------------------------------------------------- /r2pc/src/client.rs: -------------------------------------------------------------------------------- 1 | use crate::{Context, Meta}; 2 | use derse::{Deserialize, Serialize}; 3 | use std::time::Duration; 4 | 5 | pub struct Client { 6 | timeout: Duration, 7 | } 8 | 9 | impl Default for Client { 10 | fn default() -> Self { 11 | Self { 12 | timeout: 
Duration::from_secs(1), 13 | } 14 | } 15 | } 16 | 17 | impl Client { 18 | pub async fn rpc_call( 19 | &self, 20 | ctx: &Context, 21 | req: &Req, 22 | method_name: &str, 23 | ) -> std::result::Result 24 | where 25 | Req: Serialize, 26 | Rsp: for<'c> Deserialize<'c>, 27 | Error: std::error::Error + From + for<'c> Deserialize<'c>, 28 | { 29 | let meta = Meta { 30 | msg_id: Default::default(), 31 | method: method_name.into(), 32 | flags: Default::default(), 33 | }; 34 | let bytes = meta.serialize(req)?; 35 | let bytes = match tokio::time::timeout(self.timeout, ctx.tr.request(&bytes)).await { 36 | Ok(r) => r?, 37 | Err(e) => return Err(crate::Error::Timeout(e.to_string()).into()), 38 | }; 39 | let mut buf = bytes.as_slice(); 40 | let _ = Meta::deserialize_from(&mut buf).map_err(Into::into)?; 41 | std::result::Result::::deserialize_from(&mut buf).map_err(Into::into)? 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /r2pc/src/connection_pool.rs: -------------------------------------------------------------------------------- 1 | use crate::{Error, Result}; 2 | use std::net::SocketAddr; 3 | use tokio::net::TcpStream; 4 | 5 | pub struct ConnectionPool { 6 | max_connection_num: usize, 7 | map: lockmap::LockMap>, 8 | } 9 | 10 | impl ConnectionPool { 11 | pub fn new(max_connection_num: usize) -> Self { 12 | Self { 13 | max_connection_num, 14 | map: Default::default(), 15 | } 16 | } 17 | 18 | pub async fn acquire(&self, addr: SocketAddr) -> Result { 19 | let mut entry = self.map.entry(addr); 20 | if let Some(conns) = entry.get_mut() { 21 | if let Some(conn) = conns.pop() { 22 | return Ok(conn); 23 | } 24 | } 25 | drop(entry); 26 | 27 | self.connect(addr).await 28 | } 29 | 30 | pub fn restore(&self, addr: SocketAddr, stream: TcpStream) { 31 | let mut entry = self.map.entry(addr); 32 | if let Some(conns) = entry.get_mut() { 33 | if conns.len() < self.max_connection_num { 34 | conns.push(stream); 35 | } 36 | } else { 37 | 
entry.insert(vec![stream]); 38 | } 39 | } 40 | 41 | async fn connect(&self, addr: SocketAddr) -> Result { 42 | match tokio::time::timeout(std::time::Duration::from_secs(1), TcpStream::connect(&addr)) 43 | .await 44 | { 45 | Ok(r) => r.map_err(|e| Error::SocketError(e.to_string())), 46 | Err(e) => Err(Error::SocketError(e.to_string())), 47 | } 48 | } 49 | } 50 | 51 | impl std::fmt::Debug for ConnectionPool { 52 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 53 | f.debug_struct("ConnectionPool") 54 | .field("max_connection_num", &self.max_connection_num) 55 | .finish() 56 | } 57 | } 58 | 59 | #[cfg(test)] 60 | mod tests { 61 | use super::*; 62 | 63 | #[tokio::test] 64 | pub async fn test_connection_pool() { 65 | let listener = std::net::TcpListener::bind("0.0.0.0:0").unwrap(); 66 | let addr = listener.local_addr().unwrap(); 67 | 68 | let pool = ConnectionPool::new(2); 69 | let stream = pool.acquire(addr).await.unwrap(); 70 | pool.restore(addr, stream); 71 | 72 | let stream1 = pool.acquire(addr).await.unwrap(); 73 | let stream2 = pool.acquire(addr).await.unwrap(); 74 | pool.restore(addr, stream1); 75 | pool.restore(addr, stream2); 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /r2pc/src/constants.rs: -------------------------------------------------------------------------------- 1 | pub const MSG_HEADER: u32 = u32::from_be_bytes(*b"r2pc"); 2 | pub const MAX_MSG_SIZE: usize = 64 << 20; 3 | -------------------------------------------------------------------------------- /r2pc/src/context.rs: -------------------------------------------------------------------------------- 1 | use super::{Server, Transport}; 2 | use std::sync::Arc; 3 | 4 | #[derive(Clone, Debug)] 5 | pub struct Context { 6 | pub tr: Transport, 7 | pub server: Option>, 8 | } 9 | 10 | impl Context { 11 | pub fn new(tr: Transport) -> Self { 12 | Self { tr, server: None } 13 | } 14 | } 15 | 
-------------------------------------------------------------------------------- /r2pc/src/core_service/info_service.rs: -------------------------------------------------------------------------------- 1 | use crate::{Context, Result, service}; 2 | 3 | #[service] 4 | pub trait InfoService { 5 | async fn list_methods(&self, ctx: &Context, v: &()) -> Result>; 6 | } 7 | 8 | impl InfoService for super::CoreServiceImpl { 9 | async fn list_methods(&self, ctx: &Context, _: &()) -> Result> { 10 | let server = ctx.server.as_ref().unwrap(); 11 | Ok(server.methods.keys().cloned().collect()) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /r2pc/src/core_service/mod.rs: -------------------------------------------------------------------------------- 1 | mod info_service; 2 | 3 | pub use info_service::*; 4 | 5 | pub struct CoreServiceImpl; 6 | -------------------------------------------------------------------------------- /r2pc/src/error.rs: -------------------------------------------------------------------------------- 1 | #[derive(thiserror::Error, derse::Serialize, derse::Deserialize)] 2 | pub enum Error { 3 | #[error("serialization error: {0}")] 4 | DerseError(#[from] derse::Error), 5 | #[error("socket error: {0}")] 6 | SocketError(String), 7 | #[error("timeout: {0}")] 8 | Timeout(String), 9 | #[error("invalid msg: {0}")] 10 | InvalidMsg(String), 11 | } 12 | 13 | impl std::fmt::Debug for Error { 14 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 15 | std::fmt::Display::fmt(self, f) 16 | } 17 | } 18 | 19 | pub type Result = std::result::Result; 20 | -------------------------------------------------------------------------------- /r2pc/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![feature(return_type_notation)] 2 | 3 | mod client; 4 | mod connection_pool; 5 | mod constants; 6 | mod context; 7 | mod core_service; 8 | mod error; 9 | mod meta; 10 | 
// NOTE(review): The lines below are a whitespace-collapsed text dump of several r2pc files:
// the tail of r2pc/src/lib.rs (module declarations and re-exports), then r2pc/src/meta.rs,
// r2pc/src/server.rs, r2pc/src/transport.rs, r2pc/tests/test_concurrent.rs,
// r2pc/tests/test_service.rs and rust-toolchain.toml, each introduced by a "/path:" header
// between dashed rules. The "NN |" markers are the original per-file line numbers.
// NOTE(review): Generic arguments in angle brackets appear to have been lost during
// extraction (presumably stripped like HTML tags): `Meta::serialize` is declared `-> Result`
// yet returns `Ok(bytes)`, fields read `pool: Arc,` and `pub methods: HashMap,`, and both
// `pub type Method = Box Result<()> + Send + Sync>;` and `AsyncTcpStream(Arc>)` keep an
// unmatched `>`. The listings are therefore not compilable as shown; take the real
// signatures from the source tree, not from this dump.
// meta.rs, `Meta::serialize`: the payload is serialized first, then `self.serialize_to` and
// two `prepend` calls run on the `DownwardBytes`, so the frame presumably reads
// MSG_HEADER (big-endian), length (big-endian u32), meta, payload -- confirm that
// `serialize_to` prepends. That layout matches transport.rs, which reads one u64 with
// `read_u64` (big-endian in tokio), compares its high 32 bits with MSG_HEADER and uses the
// low 32 bits as the length. `bytes.len() as u32` would silently truncate a frame longer
// than u32::MAX bytes.
// meta.rs, `Meta::deserialize`: takes `&self` but the body never reads it; it only forwards
// to `Deserialize::deserialize(bytes)`.
mod server; 11 | mod transport; 12 | 13 | pub use client::Client; 14 | pub use connection_pool::*; 15 | pub use constants::*; 16 | pub use context::*; 17 | pub use core_service::*; 18 | pub use error::{Error, Result}; 19 | pub use meta::*; 20 | pub use server::*; 21 | pub use transport::*; 22 | 23 | pub use r2pc_macro::service; 24 | -------------------------------------------------------------------------------- /r2pc/src/meta.rs: -------------------------------------------------------------------------------- 1 | use crate::{MSG_HEADER, Result}; 2 | use derse::{Deserialize, DownwardBytes, Serialize}; 3 | 4 | #[derive(Serialize, Deserialize, Clone, Debug)] 5 | pub struct Meta { 6 | pub msg_id: u64, 7 | pub method: String, 8 | pub flags: u32, 9 | } 10 | 11 | impl Meta { 12 | pub fn serialize(&self, payload: &P) -> Result { 13 | let mut bytes: DownwardBytes = payload.serialize()?; 14 | self.serialize_to(&mut bytes)?; 15 | let len = bytes.len() as u32; 16 | bytes.prepend(len.to_be_bytes()); 17 | bytes.prepend(MSG_HEADER.to_be_bytes()); 18 | Ok(bytes) 19 | } 20 | 21 | pub fn deserialize<'a, P: Deserialize<'a>>(&self, bytes: &'a [u8]) -> Result

{ 22 | let payload = Deserialize::deserialize(bytes)?; 23 | Ok(payload) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /r2pc/src/server.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use derse::Deserialize; 3 | use std::{collections::HashMap, net::SocketAddr, sync::Arc}; 4 | use tokio::net::TcpStream; 5 | 6 | pub type Method = Box Result<()> + Send + Sync>; 7 | 8 | pub struct Server { 9 | stop_token: tokio_util::sync::CancellationToken, 10 | pub methods: HashMap, 11 | } 12 | 13 | impl Default for Server { 14 | fn default() -> Self { 15 | let mut this = Self { 16 | stop_token: Default::default(), 17 | methods: Default::default(), 18 | }; 19 | 20 | let core_service = Arc::new(CoreServiceImpl); 21 | this.add_methods(InfoService::rpc_export(core_service.clone())); 22 | this 23 | } 24 | } 25 | 26 | impl Server { 27 | pub fn add_methods(&mut self, methods: HashMap) { 28 | self.methods.extend(methods); 29 | } 30 | 31 | pub fn stop(&self) { 32 | self.stop_token.cancel(); 33 | } 34 | 35 | pub async fn listen( 36 | self: Arc, 37 | addr: SocketAddr, 38 | ) -> std::io::Result<(SocketAddr, tokio::task::JoinHandle<()>)> { 39 | let listener = tokio::net::TcpListener::bind(addr).await?; 40 | let listener_addr = listener.local_addr()?; 41 | let stop_token = self.stop_token.clone(); 42 | 43 | let listen_routine = tokio::spawn(async move { 44 | tokio::select!
// server.rs, `Server::listen`: `while let Ok((socket, addr)) = listener.accept().await`
// leaves the accept loop on the first accept error without logging it, after which the
// spawned listen task simply finishes. `stop()` only cancels `stop_token`, and in the code
// shown that token is awaited only by this accept loop, so connection tasks that were
// already spawned keep running after `stop()`.
// server.rs, `Server::handle`: the `loop` has no `break`, so the function can only return an
// error through `?` (for example when `recv` fails once the peer disconnects); the
// `Ok(_)` arm that logs "socket {addr} closed" in `listen` is therefore never taken.
// The `into_std().unwrap()`, `try_clone().unwrap()` and `TcpStream::from_std(..).unwrap()`
// calls panic the connection task on I/O failure instead of returning an error, unlike
// transport.rs, which maps I/O errors to `Error::SocketError`.
// `let _ = func(&ctx, meta, buf);` discards the handler result, and a request whose
// `meta.method` is not registered is skipped with no reply and no log, so the caller
// presumably waits until its own timeout.
{ 45 | _ = stop_token.cancelled() => { 46 | tracing::info!("stop accept loop"); 47 | } 48 | _ = async { 49 | while let Ok((socket, addr)) = listener.accept().await { 50 | let clone = self.clone(); 51 | tokio::spawn(async move { 52 | tracing::info!("socket {addr} established"); 53 | match clone.handle(socket).await { 54 | Ok(_) => tracing::info!("socket {addr} closed"), 55 | Err(err) => tracing::info!("socket {addr} closed with error {err}"), 56 | } 57 | }); 58 | } 59 | } => {} 60 | } 61 | }); 62 | 63 | Ok((listener_addr, listen_routine)) 64 | } 65 | 66 | pub async fn handle(self: Arc, socket: TcpStream) -> Result<()> { 67 | let recv_stream = socket.into_std().unwrap(); 68 | let send_stream = recv_stream.try_clone().unwrap(); 69 | let recv_tr = Transport::new_async(TcpStream::from_std(recv_stream).unwrap()); 70 | let send_tr = Transport::new_async(TcpStream::from_std(send_stream).unwrap()); 71 | 72 | loop { 73 | let bytes = recv_tr.recv().await?; 74 | 75 | let mut buf = bytes.as_slice(); 76 | let meta = Meta::deserialize_from(&mut buf)?; 77 | if let Some(func) = self.methods.get(&meta.method) { 78 | let ctx = Context { 79 | tr: send_tr.clone(), 80 | server: Some(self.clone()), 81 | }; 82 | let _ = func(&ctx, meta, buf); 83 | } 84 | } 85 | } 86 | } 87 | 88 | impl std::fmt::Debug for Server { 89 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 90 | f.debug_struct("Server").finish() 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /r2pc/src/transport.rs: -------------------------------------------------------------------------------- 1 | use crate::*; 2 | use std::{net::SocketAddr, sync::Arc}; 3 | use tokio::{ 4 | io::{AsyncReadExt, AsyncWriteExt}, 5 | sync::Mutex, 6 | }; 7 | 8 | #[derive(Clone, Debug)] 9 | pub enum Transport { 10 | SyncTcpStream { 11 | pool: Arc, 12 | addr: SocketAddr, 13 | }, 14 | AsyncTcpStream(Arc>), 15 | } 16 | 17 | impl Transport { 18 | pub fn new_sync(pool: Arc, addr:
// transport.rs: `request` and `recv` duplicate the frame parsing (`read_u64`, the
// MSG_HEADER check on the high 32 bits, the length from the low 32 bits, `read_exact`);
// a shared helper would keep the two paths in sync.
// The "invalid header: {:08X}" message formats the whole u64 `header`; `{:08X}` is only a
// minimum width, so up to 16 hex digits are printed.
// `len >= MAX_MSG_SIZE` also rejects a frame of exactly MAX_MSG_SIZE bytes -- confirm the
// limit is meant to be exclusive.
// In `request`, every failure returns early before `pool.restore(*addr, stream)`, so a
// connection that saw an error is not handed back to the pool.
// `request` on `AsyncTcpStream` is `todo!()` and panics; the `Context` built in
// `Server::handle` carries such a transport (`Transport::new_async`), so calling `request`
// through it would hit that panic. `send` and `recv` on `SyncTcpStream` return
// `Error::SocketError("invalid op!")`.
// The explicit `drop(socket)` in `recv` is redundant: the guard is a local of that match arm
// and would be dropped when the arm ends anyway.
SocketAddr) -> Self { 19 | Self::SyncTcpStream { pool, addr } 20 | } 21 | 22 | pub fn new_async(stream: tokio::net::TcpStream) -> Self { 23 | Self::AsyncTcpStream(Arc::new(Mutex::new(stream))) 24 | } 25 | 26 | pub async fn request(&self, bytes: &[u8]) -> Result> { 27 | match self { 28 | Transport::SyncTcpStream { pool, addr } => { 29 | let mut stream = pool.acquire(*addr).await?; 30 | stream 31 | .write_all(bytes) 32 | .await 33 | .map_err(|e| Error::SocketError(e.to_string()))?; 34 | 35 | let header = stream 36 | .read_u64() 37 | .await 38 | .map_err(|e| Error::SocketError(e.to_string()))?; 39 | 40 | if (header >> 32) as u32 != MSG_HEADER { 41 | return Err(Error::InvalidMsg(format!("invalid header: {:08X}", header))); 42 | } 43 | 44 | let len = header as u32 as usize; 45 | if len >= MAX_MSG_SIZE { 46 | return Err(Error::InvalidMsg(format!("msg is too long: {}", len))); 47 | } 48 | 49 | let mut bytes = vec![0u8; len]; 50 | stream 51 | .read_exact(&mut bytes) 52 | .await 53 | .map_err(|e| Error::SocketError(e.to_string()))?; 54 | pool.restore(*addr, stream); 55 | Ok(bytes) 56 | } 57 | Transport::AsyncTcpStream(_) => todo!(), 58 | } 59 | } 60 | 61 | pub async fn send(&self, bytes: &[u8]) -> Result<()> { 62 | match self { 63 | Transport::SyncTcpStream { pool: _, addr: _ } => { 64 | Err(Error::SocketError("invalid op!".into())) 65 | } 66 | Transport::AsyncTcpStream(tcp_stream) => { 67 | let mut socket = tcp_stream.lock().await; 68 | socket 69 | .write_all(bytes) 70 | .await 71 | .map_err(|e| Error::SocketError(e.to_string())) 72 | } 73 | } 74 | } 75 | 76 | pub async fn recv(&self) -> Result> { 77 | match self { 78 | Transport::SyncTcpStream { pool: _, addr: _ } => { 79 | Err(Error::SocketError("invalid op!".into())) 80 | } 81 | Transport::AsyncTcpStream(tcp_stream) => { 82 | let mut socket = tcp_stream.lock().await; 83 | 84 | let header = socket 85 | .read_u64() 86 | .await 87 | .map_err(|e| Error::SocketError(e.to_string()))?; 88 | 89 | if (header >> 32) as u32 !=
MSG_HEADER { 90 | return Err(Error::InvalidMsg(format!("invalid header: {:08X}", header))); 91 | } 92 | 93 | let len = header as u32 as usize; 94 | if len >= MAX_MSG_SIZE { 95 | return Err(Error::InvalidMsg(format!("msg is too long: {}", len))); 96 | } 97 | 98 | let mut bytes = vec![0u8; len]; 99 | socket 100 | .read_exact(&mut bytes) 101 | .await 102 | .map_err(|e| Error::SocketError(e.to_string()))?; 103 | 104 | drop(socket); 105 | Ok(bytes) 106 | } 107 | } 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /r2pc/tests/test_concurrent.rs: -------------------------------------------------------------------------------- 1 | #![feature(return_type_notation)] 2 | use derse::{Deserialize, Serialize}; 3 | use r2pc::{Client, ConnectionPool, Context, Result, Server, Transport}; 4 | use std::{ 5 | str::FromStr, 6 | sync::{ 7 | Arc, 8 | atomic::{AtomicUsize, Ordering}, 9 | }, 10 | }; 11 | 12 | #[derive(Debug, Serialize, Deserialize)] 13 | pub struct CallReq {} 14 | 15 | #[derive(Debug, Serialize, Deserialize)] 16 | pub struct CallRsp {} 17 | 18 | #[r2pc::service] 19 | pub trait DemoService { 20 | async fn invoke(&self, c: &Context, r: &CallReq) -> Result; 21 | } 22 | 23 | #[derive(Default)] 24 | struct DemoImpl { 25 | value: AtomicUsize, 26 | } 27 | 28 | impl DemoService for DemoImpl { 29 | async fn invoke(&self, _ctx: &Context, _req: &CallReq) -> Result { 30 | self.value.fetch_add(1, Ordering::SeqCst); 31 | Ok(CallRsp {}) 32 | } 33 | } 34 | 35 | #[tokio::test] 36 | async fn test_concurrent_call() { 37 | let demo = Arc::new(DemoImpl::default()); 38 | let mut server = Server::default(); 39 | server.add_methods(demo.clone().rpc_export()); 40 | let server = Arc::new(server); 41 | let addr = std::net::SocketAddr::from_str("0.0.0.0:0").unwrap(); 42 | 43 | let (addr, listen_handle) = server.clone().listen(addr).await.unwrap(); 44 | let pool = Arc::new(ConnectionPool::new(64)); 45 | let tr = Transport::new_sync(pool, addr);
// tests: both test files bind the server to "0.0.0.0:0" and pass the address returned by
// `listen` (the unspecified 0.0.0.0 address plus the chosen port) straight to
// `Transport::new_sync` as the connect target. Connecting to 0.0.0.0 typically works on
// Linux but may fail on other platforms (e.g. Windows); binding to "127.0.0.1:0" would give a
// directly connectable address.
// test_concurrent.rs: 32 tasks (N) each await 4096 (M) `invoke` calls in sequence, then the
// test checks that the server-side counter equals N * M.
let ctx = Context::new(tr); 47 | 48 | const N: usize = 32; 49 | const M: usize = 4096; 50 | 51 | let mut tasks = vec![]; 52 | for _ in 0..N { 53 | let ctx = ctx.clone(); 54 | tasks.push(tokio::spawn(async move { 55 | let client = Client::default(); 56 | for _ in 0..M { 57 | let req = CallReq {}; 58 | let rsp = client.invoke(&ctx, &req).await; 59 | assert!(rsp.is_ok()); 60 | } 61 | })); 62 | } 63 | for task in tasks { 64 | task.await.unwrap(); 65 | } 66 | 67 | assert_eq!(demo.value.load(Ordering::Acquire), N * M); 68 | server.stop(); 69 | let _ = listen_handle.await; 70 | } 71 | -------------------------------------------------------------------------------- /r2pc/tests/test_service.rs: -------------------------------------------------------------------------------- 1 | #![feature(return_type_notation)] 2 | use derse::{Deserialize, Serialize}; 3 | use r2pc::{Client, ConnectionPool, Context, Error, Server, Transport}; 4 | use std::{str::FromStr, sync::Arc}; 5 | 6 | #[derive(Debug, Serialize, Deserialize)] 7 | pub struct FooReq { 8 | pub data: String, 9 | } 10 | 11 | #[derive(Debug, Serialize, Deserialize)] 12 | pub struct FooRsp { 13 | pub data: String, 14 | } 15 | 16 | #[derive(Debug, Serialize, Deserialize)] 17 | pub struct BarReq { 18 | pub data: u64, 19 | } 20 | 21 | #[derive(Debug, Serialize, Deserialize)] 22 | pub struct BarRsp { 23 | pub data: u64, 24 | } 25 | 26 | #[derive(thiserror::Error, derse::Serialize, derse::Deserialize)] 27 | #[error("bar error: {0}")] 28 | struct DemoError(pub String); 29 | 30 | impl std::fmt::Debug for DemoError { 31 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 32 | std::fmt::Display::fmt(self, f) 33 | } 34 | } 35 | 36 | impl From for DemoError { 37 | fn from(e: Error) -> Self { 38 | Self(e.to_string()) 39 | } 40 | } 41 | 42 | type DemoResult = std::result::Result; 43 | 44 | #[r2pc::service] 45 | pub trait DemoService { 46 | async fn foo(&self, ctx: &Context, req: &FooReq) -> DemoResult; 47 | async fn
// test_service.rs: the `Err(e) => assert_eq!(e.to_string(), "")` arms exist only to surface
// the error text when a call fails; `panic!("unexpected error: {e}")` would say so directly.
// `DemoError` displays as "bar error: {0}" for every method, not only for `bar`.
// The `timeout` handler sleeps 10 x 1 s and the test asserts `rsp.is_err()`, which relies on
// a client-side request timeout shorter than that; the timeout value is not visible in this
// dump, so the test run time depends on it.
bar(&self, ctx: &Context, req: &BarReq) -> DemoResult; 48 | async fn timeout(&self, ctx: &Context, req: &FooReq) -> DemoResult; 49 | } 50 | 51 | struct DemoImpl; 52 | impl DemoService for DemoImpl { 53 | async fn foo(&self, ctx: &Context, req: &FooReq) -> DemoResult { 54 | tracing::info!("foo: ctx: {:?}, req: {:?}", ctx, req); 55 | Ok(FooRsp { 56 | data: req.data.clone(), 57 | }) 58 | } 59 | 60 | async fn bar(&self, ctx: &Context, req: &BarReq) -> DemoResult { 61 | tracing::info!("bar: ctx: {:?}, req: {:?}", ctx, req); 62 | Ok(BarRsp { data: req.data + 1 }) 63 | } 64 | 65 | async fn timeout(&self, _ctx: &Context, req: &FooReq) -> DemoResult { 66 | for _ in 0..10 { 67 | tokio::time::sleep(std::time::Duration::from_secs(1)).await; 68 | } 69 | Ok(FooRsp { 70 | data: req.data.clone(), 71 | }) 72 | } 73 | } 74 | 75 | #[tokio::test] 76 | async fn test_demo_service() { 77 | tracing_subscriber::fmt() 78 | .with_max_level(tracing::Level::INFO) 79 | .init(); 80 | 81 | let demo = Arc::new(DemoImpl); 82 | let mut server = Server::default(); 83 | server.add_methods(demo.rpc_export()); 84 | let server = Arc::new(server); 85 | let addr = std::net::SocketAddr::from_str("0.0.0.0:0").unwrap(); 86 | 87 | let (addr, listen_handle) = server.clone().listen(addr).await.unwrap(); 88 | let pool = Arc::new(ConnectionPool::new(16)); 89 | let tr = Transport::new_sync(pool, addr); 90 | let ctx = Context::new(tr); 91 | 92 | let client = Client::default(); 93 | let req = FooReq { data: "foo".into() }; 94 | let rsp = client.foo(&ctx, &req).await; 95 | match rsp { 96 | Ok(r) => assert_eq!(r.data, "foo"), 97 | Err(e) => assert_eq!(e.to_string(), ""), 98 | } 99 | 100 | let req = BarReq { data: 233 }; 101 | let rsp = client.bar(&ctx, &req).await; 102 | match rsp { 103 | Ok(r) => assert_eq!(r.data, 234), 104 | Err(e) => assert_eq!(e.to_string(), ""), 105 | } 106 | 107 | let req = FooReq { data: "foo".into() }; 108 | let rsp = client.timeout(&ctx, &req).await; 109 | tracing::info!("{rsp:?}"); 110 |
// rust-toolchain.toml: `channel = "nightly"` is needed at least because the tests use
// `#![feature(return_type_notation)]`; pinning a dated nightly would make CI builds
// reproducible.
assert!(rsp.is_err()); 111 | 112 | server.stop(); 113 | let _ = listen_handle.await; 114 | } 115 | 116 | #[test] 117 | fn test_demo_service_sync() { 118 | let demo = Arc::new(DemoImpl); 119 | let mut server = Server::default(); 120 | server.add_methods(demo.rpc_export()); 121 | let server = Arc::new(server); 122 | let addr = std::net::SocketAddr::from_str("0.0.0.0:0").unwrap(); 123 | 124 | let runtime = tokio::runtime::Builder::new_multi_thread() 125 | .worker_threads(1) 126 | .enable_all() 127 | .build() 128 | .unwrap(); 129 | let (addr, listen_handle) = runtime.block_on(server.clone().listen(addr)).unwrap(); 130 | 131 | let pool = Arc::new(ConnectionPool::new(16)); 132 | let tr = Transport::new_sync(pool, addr); 133 | let ctx = Context::new(tr); 134 | 135 | let req = FooReq { data: "foo".into() }; 136 | let current = tokio::runtime::Builder::new_current_thread() 137 | .enable_all() 138 | .build() 139 | .unwrap(); 140 | let client = Client::default(); 141 | let rsp = current.block_on(client.foo(&ctx, &req)); 142 | match rsp { 143 | Ok(r) => assert_eq!(r.data, "foo"), 144 | Err(e) => assert_eq!(e.to_string(), ""), 145 | } 146 | 147 | server.stop(); 148 | let _ = runtime.block_on(listen_handle); 149 | } 150 | -------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "nightly" 3 | --------------------------------------------------------------------------------