├── .github └── workflows │ ├── pr.yaml │ └── release.yaml ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── README.md ├── example-messagepack ├── Cargo.toml └── src │ ├── client.rs │ ├── messages.rs │ └── server.rs ├── example-proto ├── Cargo.toml └── src │ ├── client.rs │ ├── messages.rs │ └── server.rs ├── example-telnet ├── Cargo.toml └── src │ └── main.rs ├── protosocket-connection ├── Cargo.toml ├── README.md └── src │ ├── connection.rs │ ├── lib.rs │ ├── serde.rs │ └── types.rs ├── protosocket-messagepack ├── Cargo.toml └── src │ └── lib.rs ├── protosocket-prost ├── Cargo.toml └── src │ ├── error.rs │ ├── lib.rs │ ├── prost_client_registry.rs │ ├── prost_serializer.rs │ └── prost_socket.rs ├── protosocket-rpc ├── Cargo.toml ├── README.md └── src │ ├── client │ ├── configuration.rs │ ├── mod.rs │ ├── reactor │ │ ├── completion_reactor.rs │ │ ├── completion_registry.rs │ │ ├── completion_streaming.rs │ │ ├── completion_unary.rs │ │ ├── mod.rs │ │ └── rpc_drop_guard.rs │ └── rpc_client.rs │ ├── error.rs │ ├── lib.rs │ ├── message.rs │ └── server │ ├── abortable.rs │ ├── connection_server.rs │ ├── mod.rs │ ├── rpc_submitter.rs │ ├── server_traits.rs │ └── socket_server.rs └── protosocket-server ├── Cargo.toml └── src ├── connection_server.rs ├── error.rs └── lib.rs /.github/workflows/pr.yaml: -------------------------------------------------------------------------------- 1 | name: PR 2 | 3 | on: 4 | pull_request: 5 | branches: [ main ] 6 | 7 | env: 8 | CARGO_TERM_COLOR: always 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: dtolnay/rust-toolchain@stable 16 | with: 17 | toolchain: stable 18 | components: rustfmt, clippy 19 | - uses: Swatinem/rust-cache@v2 20 | 21 | - name: Rustfmt 22 | run: cargo fmt -- --check 23 | - name: Clippy 24 | run: | 25 | cargo --version 26 | cargo clippy --version 27 | cargo clippy --all-targets --all-features -- -D warnings -W clippy::unwrap_used 28 | - name: Run tests 29 
| run: cargo test --verbose 30 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - v0.* 7 | - v1.* 8 | - v2.* 9 | - v3.* 10 | - v4.* 11 | - v5.* 12 | 13 | env: 14 | CARGO_TERM_COLOR: always 15 | 16 | jobs: 17 | build: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v4 21 | - uses: dtolnay/rust-toolchain@stable 22 | with: 23 | toolchain: stable 24 | - uses: katyo/publish-crates@v2 25 | with: 26 | registry-token: ${{ secrets.CARGO_REGISTRY_TOKEN }} 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cross.toml 3 | 4 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 4 4 | 5 | [[package]] 6 | name = "addr2line" 7 | version = "0.22.0" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" 10 | dependencies = [ 11 | "gimli", 12 | ] 13 | 14 | [[package]] 15 | name = "adler" 16 | version = "1.0.2" 17 | source = "registry+https://github.com/rust-lang/crates.io-index" 18 | checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" 19 | 20 | [[package]] 21 | name = "aho-corasick" 22 | version = "1.1.3" 23 | source = "registry+https://github.com/rust-lang/crates.io-index" 24 | checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" 25 | dependencies = [ 26 | "memchr", 27 | ] 28 | 29 | [[package]] 30 | name = "anstream" 31 | version = "0.6.15" 32 | source = "registry+https://github.com/rust-lang/crates.io-index" 33 | checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" 34 | dependencies = [ 35 | "anstyle", 36 | "anstyle-parse", 37 | "anstyle-query", 38 | "anstyle-wincon", 39 | "colorchoice", 40 | "is_terminal_polyfill", 41 | "utf8parse", 42 | ] 43 | 44 | [[package]] 45 | name = "anstyle" 46 | version = "1.0.8" 47 | source = "registry+https://github.com/rust-lang/crates.io-index" 48 | checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" 49 | 50 | [[package]] 51 | name = "anstyle-parse" 52 | version = "0.2.5" 53 | source = "registry+https://github.com/rust-lang/crates.io-index" 54 | checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" 55 | dependencies = [ 56 | "utf8parse", 57 | ] 58 | 59 | [[package]] 60 | name = "anstyle-query" 61 | version = "1.1.1" 62 | source = "registry+https://github.com/rust-lang/crates.io-index" 63 | checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" 64 | dependencies = [ 65 | "windows-sys 0.52.0", 66 | ] 67 | 68 | [[package]] 69 | name = "anstyle-wincon" 70 | 
version = "3.0.4" 71 | source = "registry+https://github.com/rust-lang/crates.io-index" 72 | checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" 73 | dependencies = [ 74 | "anstyle", 75 | "windows-sys 0.52.0", 76 | ] 77 | 78 | [[package]] 79 | name = "anyhow" 80 | version = "1.0.86" 81 | source = "registry+https://github.com/rust-lang/crates.io-index" 82 | checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" 83 | 84 | [[package]] 85 | name = "atomic-wait" 86 | version = "1.1.0" 87 | source = "registry+https://github.com/rust-lang/crates.io-index" 88 | checksum = "a55b94919229f2c42292fd71ffa4b75e83193bffdd77b1e858cd55fd2d0b0ea8" 89 | dependencies = [ 90 | "libc", 91 | "windows-sys 0.42.0", 92 | ] 93 | 94 | [[package]] 95 | name = "autocfg" 96 | version = "1.3.0" 97 | source = "registry+https://github.com/rust-lang/crates.io-index" 98 | checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" 99 | 100 | [[package]] 101 | name = "backtrace" 102 | version = "0.3.73" 103 | source = "registry+https://github.com/rust-lang/crates.io-index" 104 | checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" 105 | dependencies = [ 106 | "addr2line", 107 | "cc", 108 | "cfg-if", 109 | "libc", 110 | "miniz_oxide", 111 | "object", 112 | "rustc-demangle", 113 | ] 114 | 115 | [[package]] 116 | name = "bitflags" 117 | version = "2.6.0" 118 | source = "registry+https://github.com/rust-lang/crates.io-index" 119 | checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" 120 | 121 | [[package]] 122 | name = "byteorder" 123 | version = "1.5.0" 124 | source = "registry+https://github.com/rust-lang/crates.io-index" 125 | checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" 126 | 127 | [[package]] 128 | name = "bytes" 129 | version = "1.6.1" 130 | source = "registry+https://github.com/rust-lang/crates.io-index" 131 | checksum = 
"a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" 132 | 133 | [[package]] 134 | name = "cc" 135 | version = "1.1.6" 136 | source = "registry+https://github.com/rust-lang/crates.io-index" 137 | checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" 138 | 139 | [[package]] 140 | name = "cfg-if" 141 | version = "1.0.0" 142 | source = "registry+https://github.com/rust-lang/crates.io-index" 143 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 144 | 145 | [[package]] 146 | name = "colorchoice" 147 | version = "1.0.2" 148 | source = "registry+https://github.com/rust-lang/crates.io-index" 149 | checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" 150 | 151 | [[package]] 152 | name = "either" 153 | version = "1.13.0" 154 | source = "registry+https://github.com/rust-lang/crates.io-index" 155 | checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" 156 | 157 | [[package]] 158 | name = "env_filter" 159 | version = "0.1.2" 160 | source = "registry+https://github.com/rust-lang/crates.io-index" 161 | checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" 162 | dependencies = [ 163 | "log", 164 | "regex", 165 | ] 166 | 167 | [[package]] 168 | name = "env_logger" 169 | version = "0.11.5" 170 | source = "registry+https://github.com/rust-lang/crates.io-index" 171 | checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" 172 | dependencies = [ 173 | "anstream", 174 | "anstyle", 175 | "env_filter", 176 | "humantime", 177 | "log", 178 | ] 179 | 180 | [[package]] 181 | name = "example" 182 | version = "0.1.0" 183 | dependencies = [ 184 | "bytes", 185 | "env_logger", 186 | "futures", 187 | "log", 188 | "protosocket", 189 | "protosocket-server", 190 | "tokio", 191 | ] 192 | 193 | [[package]] 194 | name = "example-messagepack" 195 | version = "0.1.0" 196 | dependencies = [ 197 | "bytes", 198 | "env_logger", 199 | "futures", 200 | 
"histogram", 201 | "log", 202 | "protosocket", 203 | "protosocket-messagepack", 204 | "protosocket-rpc", 205 | "serde", 206 | "tokio", 207 | ] 208 | 209 | [[package]] 210 | name = "example-proto" 211 | version = "0.1.0" 212 | dependencies = [ 213 | "bytes", 214 | "env_logger", 215 | "futures", 216 | "histogram", 217 | "log", 218 | "prost", 219 | "protosocket", 220 | "protosocket-prost", 221 | "protosocket-rpc", 222 | "protosocket-server", 223 | "tokio", 224 | ] 225 | 226 | [[package]] 227 | name = "futures" 228 | version = "0.3.30" 229 | source = "registry+https://github.com/rust-lang/crates.io-index" 230 | checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" 231 | dependencies = [ 232 | "futures-channel", 233 | "futures-core", 234 | "futures-executor", 235 | "futures-io", 236 | "futures-sink", 237 | "futures-task", 238 | "futures-util", 239 | ] 240 | 241 | [[package]] 242 | name = "futures-channel" 243 | version = "0.3.30" 244 | source = "registry+https://github.com/rust-lang/crates.io-index" 245 | checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" 246 | dependencies = [ 247 | "futures-core", 248 | "futures-sink", 249 | ] 250 | 251 | [[package]] 252 | name = "futures-core" 253 | version = "0.3.30" 254 | source = "registry+https://github.com/rust-lang/crates.io-index" 255 | checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" 256 | 257 | [[package]] 258 | name = "futures-executor" 259 | version = "0.3.30" 260 | source = "registry+https://github.com/rust-lang/crates.io-index" 261 | checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" 262 | dependencies = [ 263 | "futures-core", 264 | "futures-task", 265 | "futures-util", 266 | ] 267 | 268 | [[package]] 269 | name = "futures-io" 270 | version = "0.3.30" 271 | source = "registry+https://github.com/rust-lang/crates.io-index" 272 | checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" 273 | 274 | 
[[package]] 275 | name = "futures-macro" 276 | version = "0.3.30" 277 | source = "registry+https://github.com/rust-lang/crates.io-index" 278 | checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" 279 | dependencies = [ 280 | "proc-macro2", 281 | "quote", 282 | "syn", 283 | ] 284 | 285 | [[package]] 286 | name = "futures-sink" 287 | version = "0.3.30" 288 | source = "registry+https://github.com/rust-lang/crates.io-index" 289 | checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" 290 | 291 | [[package]] 292 | name = "futures-task" 293 | version = "0.3.30" 294 | source = "registry+https://github.com/rust-lang/crates.io-index" 295 | checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" 296 | 297 | [[package]] 298 | name = "futures-util" 299 | version = "0.3.30" 300 | source = "registry+https://github.com/rust-lang/crates.io-index" 301 | checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" 302 | dependencies = [ 303 | "futures-channel", 304 | "futures-core", 305 | "futures-io", 306 | "futures-macro", 307 | "futures-sink", 308 | "futures-task", 309 | "memchr", 310 | "pin-project-lite", 311 | "pin-utils", 312 | "slab", 313 | ] 314 | 315 | [[package]] 316 | name = "gimli" 317 | version = "0.29.0" 318 | source = "registry+https://github.com/rust-lang/crates.io-index" 319 | checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" 320 | 321 | [[package]] 322 | name = "hermit-abi" 323 | version = "0.3.9" 324 | source = "registry+https://github.com/rust-lang/crates.io-index" 325 | checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" 326 | 327 | [[package]] 328 | name = "histogram" 329 | version = "0.11.0" 330 | source = "registry+https://github.com/rust-lang/crates.io-index" 331 | checksum = "b62b8d85713ddc62e5e78db13bf9f9305610d0419276faa845076a68b7165872" 332 | dependencies = [ 333 | "thiserror", 334 | ] 335 | 336 | [[package]] 337 | 
name = "humantime" 338 | version = "2.1.0" 339 | source = "registry+https://github.com/rust-lang/crates.io-index" 340 | checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" 341 | 342 | [[package]] 343 | name = "is_terminal_polyfill" 344 | version = "1.70.1" 345 | source = "registry+https://github.com/rust-lang/crates.io-index" 346 | checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" 347 | 348 | [[package]] 349 | name = "itertools" 350 | version = "0.13.0" 351 | source = "registry+https://github.com/rust-lang/crates.io-index" 352 | checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" 353 | dependencies = [ 354 | "either", 355 | ] 356 | 357 | [[package]] 358 | name = "k-lock" 359 | version = "0.2.5" 360 | source = "registry+https://github.com/rust-lang/crates.io-index" 361 | checksum = "ab228fb4852ec5d997306dc0d652235b0de0595dc1f0a7a2bd96b03730b96474" 362 | dependencies = [ 363 | "atomic-wait", 364 | "readme-rustdocifier", 365 | ] 366 | 367 | [[package]] 368 | name = "libc" 369 | version = "0.2.155" 370 | source = "registry+https://github.com/rust-lang/crates.io-index" 371 | checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" 372 | 373 | [[package]] 374 | name = "lock_api" 375 | version = "0.4.12" 376 | source = "registry+https://github.com/rust-lang/crates.io-index" 377 | checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" 378 | dependencies = [ 379 | "autocfg", 380 | "scopeguard", 381 | ] 382 | 383 | [[package]] 384 | name = "log" 385 | version = "0.4.22" 386 | source = "registry+https://github.com/rust-lang/crates.io-index" 387 | checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" 388 | 389 | [[package]] 390 | name = "memchr" 391 | version = "2.7.4" 392 | source = "registry+https://github.com/rust-lang/crates.io-index" 393 | checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" 394 | 395 
| [[package]] 396 | name = "miniz_oxide" 397 | version = "0.7.4" 398 | source = "registry+https://github.com/rust-lang/crates.io-index" 399 | checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" 400 | dependencies = [ 401 | "adler", 402 | ] 403 | 404 | [[package]] 405 | name = "mio" 406 | version = "1.0.1" 407 | source = "registry+https://github.com/rust-lang/crates.io-index" 408 | checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" 409 | dependencies = [ 410 | "hermit-abi", 411 | "libc", 412 | "wasi", 413 | "windows-sys 0.52.0", 414 | ] 415 | 416 | [[package]] 417 | name = "num-traits" 418 | version = "0.2.19" 419 | source = "registry+https://github.com/rust-lang/crates.io-index" 420 | checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" 421 | dependencies = [ 422 | "autocfg", 423 | ] 424 | 425 | [[package]] 426 | name = "object" 427 | version = "0.36.2" 428 | source = "registry+https://github.com/rust-lang/crates.io-index" 429 | checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" 430 | dependencies = [ 431 | "memchr", 432 | ] 433 | 434 | [[package]] 435 | name = "parking_lot" 436 | version = "0.12.3" 437 | source = "registry+https://github.com/rust-lang/crates.io-index" 438 | checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" 439 | dependencies = [ 440 | "lock_api", 441 | "parking_lot_core", 442 | ] 443 | 444 | [[package]] 445 | name = "parking_lot_core" 446 | version = "0.9.10" 447 | source = "registry+https://github.com/rust-lang/crates.io-index" 448 | checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" 449 | dependencies = [ 450 | "cfg-if", 451 | "libc", 452 | "redox_syscall", 453 | "smallvec", 454 | "windows-targets", 455 | ] 456 | 457 | [[package]] 458 | name = "paste" 459 | version = "1.0.15" 460 | source = "registry+https://github.com/rust-lang/crates.io-index" 461 | checksum = 
"57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" 462 | 463 | [[package]] 464 | name = "pin-project-lite" 465 | version = "0.2.14" 466 | source = "registry+https://github.com/rust-lang/crates.io-index" 467 | checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" 468 | 469 | [[package]] 470 | name = "pin-utils" 471 | version = "0.1.0" 472 | source = "registry+https://github.com/rust-lang/crates.io-index" 473 | checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" 474 | 475 | [[package]] 476 | name = "proc-macro2" 477 | version = "1.0.86" 478 | source = "registry+https://github.com/rust-lang/crates.io-index" 479 | checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" 480 | dependencies = [ 481 | "unicode-ident", 482 | ] 483 | 484 | [[package]] 485 | name = "prost" 486 | version = "0.13.1" 487 | source = "registry+https://github.com/rust-lang/crates.io-index" 488 | checksum = "e13db3d3fde688c61e2446b4d843bc27a7e8af269a69440c0308021dc92333cc" 489 | dependencies = [ 490 | "bytes", 491 | "prost-derive", 492 | ] 493 | 494 | [[package]] 495 | name = "prost-derive" 496 | version = "0.13.1" 497 | source = "registry+https://github.com/rust-lang/crates.io-index" 498 | checksum = "18bec9b0adc4eba778b33684b7ba3e7137789434769ee3ce3930463ef904cfca" 499 | dependencies = [ 500 | "anyhow", 501 | "itertools", 502 | "proc-macro2", 503 | "quote", 504 | "syn", 505 | ] 506 | 507 | [[package]] 508 | name = "protosocket" 509 | version = "0.8.0" 510 | dependencies = [ 511 | "bytes", 512 | "futures", 513 | "log", 514 | "serde", 515 | "thiserror", 516 | "tokio", 517 | "tokio-util", 518 | ] 519 | 520 | [[package]] 521 | name = "protosocket-messagepack" 522 | version = "0.8.0" 523 | dependencies = [ 524 | "bytes", 525 | "log", 526 | "protosocket", 527 | "rmp", 528 | "rmp-serde", 529 | "serde", 530 | ] 531 | 532 | [[package]] 533 | name = "protosocket-prost" 534 | version = "0.8.0" 535 | dependencies = [ 536 | 
"bytes", 537 | "log", 538 | "prost", 539 | "protosocket", 540 | "thiserror", 541 | "tokio", 542 | ] 543 | 544 | [[package]] 545 | name = "protosocket-rpc" 546 | version = "0.8.0" 547 | dependencies = [ 548 | "futures", 549 | "k-lock", 550 | "log", 551 | "prost", 552 | "protosocket", 553 | "thiserror", 554 | "tokio", 555 | "tokio-util", 556 | ] 557 | 558 | [[package]] 559 | name = "protosocket-server" 560 | version = "0.8.0" 561 | dependencies = [ 562 | "bytes", 563 | "futures", 564 | "log", 565 | "protosocket", 566 | "thiserror", 567 | "tokio", 568 | ] 569 | 570 | [[package]] 571 | name = "quote" 572 | version = "1.0.36" 573 | source = "registry+https://github.com/rust-lang/crates.io-index" 574 | checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" 575 | dependencies = [ 576 | "proc-macro2", 577 | ] 578 | 579 | [[package]] 580 | name = "readme-rustdocifier" 581 | version = "0.1.1" 582 | source = "registry+https://github.com/rust-lang/crates.io-index" 583 | checksum = "08ad765b21a08b1a8e5cdce052719188a23772bcbefb3c439f0baaf62c56ceac" 584 | 585 | [[package]] 586 | name = "redox_syscall" 587 | version = "0.5.3" 588 | source = "registry+https://github.com/rust-lang/crates.io-index" 589 | checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" 590 | dependencies = [ 591 | "bitflags", 592 | ] 593 | 594 | [[package]] 595 | name = "regex" 596 | version = "1.10.5" 597 | source = "registry+https://github.com/rust-lang/crates.io-index" 598 | checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" 599 | dependencies = [ 600 | "aho-corasick", 601 | "memchr", 602 | "regex-automata", 603 | "regex-syntax", 604 | ] 605 | 606 | [[package]] 607 | name = "regex-automata" 608 | version = "0.4.7" 609 | source = "registry+https://github.com/rust-lang/crates.io-index" 610 | checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" 611 | dependencies = [ 612 | "aho-corasick", 613 | "memchr", 614 | 
"regex-syntax", 615 | ] 616 | 617 | [[package]] 618 | name = "regex-syntax" 619 | version = "0.8.4" 620 | source = "registry+https://github.com/rust-lang/crates.io-index" 621 | checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" 622 | 623 | [[package]] 624 | name = "rmp" 625 | version = "0.8.14" 626 | source = "registry+https://github.com/rust-lang/crates.io-index" 627 | checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" 628 | dependencies = [ 629 | "byteorder", 630 | "num-traits", 631 | "paste", 632 | ] 633 | 634 | [[package]] 635 | name = "rmp-serde" 636 | version = "1.3.0" 637 | source = "registry+https://github.com/rust-lang/crates.io-index" 638 | checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" 639 | dependencies = [ 640 | "byteorder", 641 | "rmp", 642 | "serde", 643 | ] 644 | 645 | [[package]] 646 | name = "rustc-demangle" 647 | version = "0.1.24" 648 | source = "registry+https://github.com/rust-lang/crates.io-index" 649 | checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" 650 | 651 | [[package]] 652 | name = "scopeguard" 653 | version = "1.2.0" 654 | source = "registry+https://github.com/rust-lang/crates.io-index" 655 | checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" 656 | 657 | [[package]] 658 | name = "serde" 659 | version = "1.0.210" 660 | source = "registry+https://github.com/rust-lang/crates.io-index" 661 | checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" 662 | dependencies = [ 663 | "serde_derive", 664 | ] 665 | 666 | [[package]] 667 | name = "serde_derive" 668 | version = "1.0.210" 669 | source = "registry+https://github.com/rust-lang/crates.io-index" 670 | checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" 671 | dependencies = [ 672 | "proc-macro2", 673 | "quote", 674 | "syn", 675 | ] 676 | 677 | [[package]] 678 | name = "signal-hook-registry" 679 | version = 
"1.4.2" 680 | source = "registry+https://github.com/rust-lang/crates.io-index" 681 | checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" 682 | dependencies = [ 683 | "libc", 684 | ] 685 | 686 | [[package]] 687 | name = "slab" 688 | version = "0.4.9" 689 | source = "registry+https://github.com/rust-lang/crates.io-index" 690 | checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" 691 | dependencies = [ 692 | "autocfg", 693 | ] 694 | 695 | [[package]] 696 | name = "smallvec" 697 | version = "1.13.2" 698 | source = "registry+https://github.com/rust-lang/crates.io-index" 699 | checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" 700 | 701 | [[package]] 702 | name = "socket2" 703 | version = "0.5.7" 704 | source = "registry+https://github.com/rust-lang/crates.io-index" 705 | checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" 706 | dependencies = [ 707 | "libc", 708 | "windows-sys 0.52.0", 709 | ] 710 | 711 | [[package]] 712 | name = "syn" 713 | version = "2.0.72" 714 | source = "registry+https://github.com/rust-lang/crates.io-index" 715 | checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" 716 | dependencies = [ 717 | "proc-macro2", 718 | "quote", 719 | "unicode-ident", 720 | ] 721 | 722 | [[package]] 723 | name = "thiserror" 724 | version = "1.0.63" 725 | source = "registry+https://github.com/rust-lang/crates.io-index" 726 | checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" 727 | dependencies = [ 728 | "thiserror-impl", 729 | ] 730 | 731 | [[package]] 732 | name = "thiserror-impl" 733 | version = "1.0.63" 734 | source = "registry+https://github.com/rust-lang/crates.io-index" 735 | checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" 736 | dependencies = [ 737 | "proc-macro2", 738 | "quote", 739 | "syn", 740 | ] 741 | 742 | [[package]] 743 | name = "tokio" 744 | version = "1.39.1" 745 | source 
= "registry+https://github.com/rust-lang/crates.io-index" 746 | checksum = "d040ac2b29ab03b09d4129c2f5bbd012a3ac2f79d38ff506a4bf8dd34b0eac8a" 747 | dependencies = [ 748 | "backtrace", 749 | "bytes", 750 | "libc", 751 | "mio", 752 | "parking_lot", 753 | "pin-project-lite", 754 | "signal-hook-registry", 755 | "socket2", 756 | "tokio-macros", 757 | "windows-sys 0.52.0", 758 | ] 759 | 760 | [[package]] 761 | name = "tokio-macros" 762 | version = "2.4.0" 763 | source = "registry+https://github.com/rust-lang/crates.io-index" 764 | checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" 765 | dependencies = [ 766 | "proc-macro2", 767 | "quote", 768 | "syn", 769 | ] 770 | 771 | [[package]] 772 | name = "tokio-util" 773 | version = "0.7.11" 774 | source = "registry+https://github.com/rust-lang/crates.io-index" 775 | checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" 776 | dependencies = [ 777 | "bytes", 778 | "futures-core", 779 | "futures-sink", 780 | "pin-project-lite", 781 | "tokio", 782 | ] 783 | 784 | [[package]] 785 | name = "unicode-ident" 786 | version = "1.0.12" 787 | source = "registry+https://github.com/rust-lang/crates.io-index" 788 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" 789 | 790 | [[package]] 791 | name = "utf8parse" 792 | version = "0.2.2" 793 | source = "registry+https://github.com/rust-lang/crates.io-index" 794 | checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" 795 | 796 | [[package]] 797 | name = "wasi" 798 | version = "0.11.0+wasi-snapshot-preview1" 799 | source = "registry+https://github.com/rust-lang/crates.io-index" 800 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 801 | 802 | [[package]] 803 | name = "windows-sys" 804 | version = "0.42.0" 805 | source = "registry+https://github.com/rust-lang/crates.io-index" 806 | checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" 807 | 
dependencies = [ 808 | "windows_aarch64_gnullvm 0.42.2", 809 | "windows_aarch64_msvc 0.42.2", 810 | "windows_i686_gnu 0.42.2", 811 | "windows_i686_msvc 0.42.2", 812 | "windows_x86_64_gnu 0.42.2", 813 | "windows_x86_64_gnullvm 0.42.2", 814 | "windows_x86_64_msvc 0.42.2", 815 | ] 816 | 817 | [[package]] 818 | name = "windows-sys" 819 | version = "0.52.0" 820 | source = "registry+https://github.com/rust-lang/crates.io-index" 821 | checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" 822 | dependencies = [ 823 | "windows-targets", 824 | ] 825 | 826 | [[package]] 827 | name = "windows-targets" 828 | version = "0.52.6" 829 | source = "registry+https://github.com/rust-lang/crates.io-index" 830 | checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" 831 | dependencies = [ 832 | "windows_aarch64_gnullvm 0.52.6", 833 | "windows_aarch64_msvc 0.52.6", 834 | "windows_i686_gnu 0.52.6", 835 | "windows_i686_gnullvm", 836 | "windows_i686_msvc 0.52.6", 837 | "windows_x86_64_gnu 0.52.6", 838 | "windows_x86_64_gnullvm 0.52.6", 839 | "windows_x86_64_msvc 0.52.6", 840 | ] 841 | 842 | [[package]] 843 | name = "windows_aarch64_gnullvm" 844 | version = "0.42.2" 845 | source = "registry+https://github.com/rust-lang/crates.io-index" 846 | checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" 847 | 848 | [[package]] 849 | name = "windows_aarch64_gnullvm" 850 | version = "0.52.6" 851 | source = "registry+https://github.com/rust-lang/crates.io-index" 852 | checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" 853 | 854 | [[package]] 855 | name = "windows_aarch64_msvc" 856 | version = "0.42.2" 857 | source = "registry+https://github.com/rust-lang/crates.io-index" 858 | checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" 859 | 860 | [[package]] 861 | name = "windows_aarch64_msvc" 862 | version = "0.52.6" 863 | source = "registry+https://github.com/rust-lang/crates.io-index" 
864 | checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" 865 | 866 | [[package]] 867 | name = "windows_i686_gnu" 868 | version = "0.42.2" 869 | source = "registry+https://github.com/rust-lang/crates.io-index" 870 | checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" 871 | 872 | [[package]] 873 | name = "windows_i686_gnu" 874 | version = "0.52.6" 875 | source = "registry+https://github.com/rust-lang/crates.io-index" 876 | checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" 877 | 878 | [[package]] 879 | name = "windows_i686_gnullvm" 880 | version = "0.52.6" 881 | source = "registry+https://github.com/rust-lang/crates.io-index" 882 | checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" 883 | 884 | [[package]] 885 | name = "windows_i686_msvc" 886 | version = "0.42.2" 887 | source = "registry+https://github.com/rust-lang/crates.io-index" 888 | checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" 889 | 890 | [[package]] 891 | name = "windows_i686_msvc" 892 | version = "0.52.6" 893 | source = "registry+https://github.com/rust-lang/crates.io-index" 894 | checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" 895 | 896 | [[package]] 897 | name = "windows_x86_64_gnu" 898 | version = "0.42.2" 899 | source = "registry+https://github.com/rust-lang/crates.io-index" 900 | checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" 901 | 902 | [[package]] 903 | name = "windows_x86_64_gnu" 904 | version = "0.52.6" 905 | source = "registry+https://github.com/rust-lang/crates.io-index" 906 | checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" 907 | 908 | [[package]] 909 | name = "windows_x86_64_gnullvm" 910 | version = "0.42.2" 911 | source = "registry+https://github.com/rust-lang/crates.io-index" 912 | checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" 913 | 914 | 
[[package]] 915 | name = "windows_x86_64_gnullvm" 916 | version = "0.52.6" 917 | source = "registry+https://github.com/rust-lang/crates.io-index" 918 | checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" 919 | 920 | [[package]] 921 | name = "windows_x86_64_msvc" 922 | version = "0.42.2" 923 | source = "registry+https://github.com/rust-lang/crates.io-index" 924 | checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" 925 | 926 | [[package]] 927 | name = "windows_x86_64_msvc" 928 | version = "0.52.6" 929 | source = "registry+https://github.com/rust-lang/crates.io-index" 930 | checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" 931 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | resolver = "2" 3 | 4 | members = [ 5 | "example-messagepack", 6 | "example-proto", 7 | "example-telnet", 8 | "protosocket-connection", 9 | "protosocket-prost", 10 | "protosocket-rpc", 11 | "protosocket-messagepack", 12 | "protosocket-server", 13 | ] 14 | 15 | [workspace.package] 16 | version = "0.8.0" 17 | authors = ["kvc0"] 18 | repository = "https://github.com/kvc0/protosocket" 19 | edition = "2021" 20 | license = "Apache-2.0" 21 | readme = "README.md" 22 | keywords = ["tcp", "protobuf", "service", "performance"] 23 | categories = ["web-programming"] 24 | 25 | 26 | 27 | [workspace.dependencies] 28 | protosocket = { path = "protosocket-connection", version = "0" } 29 | protosocket-messagepack = { path = "protosocket-messagepack" } 30 | protosocket-rpc = { path = "protosocket-rpc" } 31 | protosocket-server = { path = "protosocket-server" } 32 | protosocket-prost = { path = "protosocket-prost" } 33 | 34 | bytes = { version = "1.6" } 35 | env_logger = { version = "0.11" } 36 | futures = { version = "0.3" } 37 | k-lock = { version = "0.2" } 38 | log = { version = "0.4" } 39 | 
prost = { version = "0.13" } 40 | rmp = { version = "0.8" } 41 | rmp-serde = { version = "1.3" } 42 | serde = { version = "1.0" } 43 | thiserror = { version = "1.0" } 44 | tokio = { version = "1.39", features = ["net", "rt"] } 45 | tokio-util = { version = "0.7" } 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # protosocket 2 | Message-oriented, low-abstraction tcp streams. 3 | 4 | A protosocket is a non-blocking, bidirectional, message streaming connection. 5 | Providing a serializer and deserializer for your messages, you can stream to 6 | and from tcp servers. 7 | 8 | There is no wrapper encoding - no HTTP, no gRPC, no websockets. You depend on 9 | TCP and your serialization strategy. 10 | 11 | Dependencies are trim; `tokio` is the main hard dependency. If you use protocol 12 | buffers, you will also depend on `prost`. There's no extra underlying framework. 13 | 14 | Protosockets avoid too many opinions - you have (get?) to choose your own 15 | message ordering and concurrency semantics. You can make an implicitly ordered 16 | stream, or a non-blocking out-of-order stream, or anything in between. 17 | 18 | Tools to facilitate protocol buffers are provided in [`protosocket-prost`](./protosocket-prost/). 19 | 20 | You can write an RPC client/server with [`protosocket-rpc`](./protosocket-rpc/). 21 | 22 | You can see an example of protocol buffers RPC in [`example-proto`](./example-proto/). 23 | 24 | # Case study 25 | ## Background 26 | (Full disclosure: I work at Momento at time of writing this): [Momento](https://www.gomomento.com/) 27 | has historically been a gRPC company. In a particular backend service with a 28 | fairly high message rate, the synchronization in `h2` under `tonic` was seen 29 | to be blocking threads in the `tokio` request runtime too much. This was causing 30 | task starvation and long polls. 
31 | 32 | The starvation was tough to see, but it happens with `lock_contended` stacks 33 | underneath the `std::sync::Mutex` while trying to work with `h2` buffers. That 34 | mutex is okay, but when the `futex` syscall parks a thread, it takes hundreds 35 | of microseconds to get the thread going again on Momento's servers. It causes 36 | extra latency that you can't easily measure, because tasks are also not picked 37 | up promptly in these cases. 38 | 39 | I was able to get 20% greater throughput by writing [k-lock](https://github.com/kvc0/k-lock) 40 | and replacing the imports in `h2` for `std::sync::Mutex` with `k_lock::Mutex`. 41 | This import-replacement for `std::sync::Mutex` tries to be more appropriate for 42 | `tokio` servers. Basically, it uses a couple heuristics to both wake and spin 43 | more aggressively than the standard mutex. This is better for `tokio` servers, 44 | because those threads absolutely _must_ finish poll() asap, and a futex park 45 | blows the poll() budget out the window. 46 | 47 | 20% wasn't enough, and the main task threads were still getting starved. So I 48 | pulled `protosocket` out of [rmemstore](https://github.com/kvc0/rmemstore/), 49 | to try it out on Momento's servers. 50 | 51 | ## Test setup 52 | Momento has a daily latency and throughput test to monitor how changes are affecting 53 | system performance. This test uses a small server setup to more easily stress 54 | the service (as opposed to stressing the load generator). 55 | 56 | Latency is measured outside of Momento, at the client. It includes a lot of factors 57 | that are not directly under Momento control, and offers a full picture of how 58 | the service could look to users (if it were deployed in a small setup like the 59 | test). 
60 | 61 | The number Momento looked at historically was `at which throughput threshold does 62 | the server pass 5 milliseconds at p99.9 tail latency?` 63 | 64 | ## Results 65 | | | Throughput | Latency | 66 | | ------------- | --------- | --------- | 67 | | **gRPC** | ![grpc throughput peaking at 57.7khz](https://github.com/user-attachments/assets/2a7c9c91-d0c5-410a-adda-d4337432c1c7) | ![grpc latency surpassing 5ms p99.9 below 20khz](https://github.com/user-attachments/assets/15e8c3ec-d4f8-4fed-a236-ae40d08f6e93) | 68 | | **protosockets** | ![protosockets throughput peaking at 75khz](https://github.com/user-attachments/assets/d1bf1bf3-3640-45d8-9a55-482844f5993a) | ![protosockets latency surpassing 5ms p99.9 above 55khz](https://github.com/user-attachments/assets/c8c90a8a-8f97-403d-b2e4-fc59eccb6b82) | 69 | 70 | Achievable throughput increased, but latency at all throughputs was significantly 71 | reduced. This improved the effective vertical scale of the reference workflow. 72 | 73 | The effective vertical scale of the small reference server was improved by 2.75x 74 | for this workflow by switching the backend protocol from gRPC to protosockets. 
75 | -------------------------------------------------------------------------------- /example-messagepack/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "example-messagepack" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [[bin]] 8 | name = "example-messagepack-server" 9 | path = "src/server.rs" 10 | 11 | [[bin]] 12 | name = "example-messagepack-client" 13 | path = "src/client.rs" 14 | 15 | [dependencies] 16 | protosocket = { workspace = true } 17 | protosocket-messagepack = { workspace = true } 18 | protosocket-rpc = { workspace = true } 19 | 20 | bytes = { workspace = true } 21 | env_logger = { workspace = true } 22 | futures = { workspace = true } 23 | log = { workspace = true } 24 | serde = { workspace = true, features = ["derive"] } 25 | tokio = { workspace = true, features = ["full"] } 26 | 27 | histogram = { version = "0.11" } 28 | -------------------------------------------------------------------------------- /example-messagepack/src/client.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | sync::{atomic::AtomicUsize, Arc}, 3 | time::{Duration, Instant, SystemTime, UNIX_EPOCH}, 4 | }; 5 | 6 | use futures::StreamExt; 7 | use messages::{EchoRequest, EchoResponseKind, Request, Response, ResponseBehavior}; 8 | use protosocket_rpc::{client::RpcClient, ProtosocketControlCode}; 9 | use tokio::sync::Semaphore; 10 | 11 | mod messages; 12 | 13 | fn main() -> Result<(), Box> { 14 | static I: AtomicUsize = AtomicUsize::new(0); 15 | let runtime = tokio::runtime::Builder::new_multi_thread() 16 | .thread_name_fn(|| { 17 | format!( 18 | "app-{}", 19 | I.fetch_add(1, std::sync::atomic::Ordering::Relaxed) 20 | ) 21 | }) 22 | .worker_threads(2) 23 | .event_interval(7) 24 | .enable_all() 25 | .build()?; 26 | 27 | runtime.block_on(run_main()) 28 | } 29 | 30 | async fn run_main() -> Result<(), Box> { 31 | env_logger::init(); 32 | 
33 | let response_count = Arc::new(AtomicUsize::new(0)); 34 | let latency = Arc::new(histogram::AtomicHistogram::new(7, 52).expect("histogram works")); 35 | 36 | let max_concurrent = 512; 37 | let concurrent_count = Arc::new(Semaphore::new(max_concurrent)); 38 | for _i in 0..2 { 39 | let (client, connection) = protosocket_rpc::client::connect::< 40 | protosocket_messagepack::ProtosocketMessagePackSerializer, 41 | protosocket_messagepack::ProtosocketMessagePackDeserializer, 42 | >( 43 | std::env::var("ENDPOINT") 44 | .unwrap_or_else(|_| "127.0.0.1:9000".to_string()) 45 | .parse() 46 | .expect("must use a valid socket address"), 47 | &Default::default(), 48 | ) 49 | .await?; 50 | let _connection_handle = tokio::spawn(connection); 51 | let _client_handle = tokio::spawn(generate_traffic( 52 | concurrent_count.clone(), 53 | client, 54 | response_count.clone(), 55 | latency.clone(), 56 | )); 57 | } 58 | 59 | let metrics = tokio::spawn(print_periodic_metrics( 60 | response_count, 61 | latency, 62 | concurrent_count, 63 | max_concurrent, 64 | )); 65 | 66 | tokio::select!( 67 | // _ = connection_driver => { 68 | // log::warn!("connection driver quit"); 69 | // } 70 | // _ = client_runtime => { 71 | // log::warn!("client runtime quit"); 72 | // } 73 | _ = metrics => { 74 | log::warn!("metrics runtime quit"); 75 | } 76 | ); 77 | 78 | Ok(()) 79 | } 80 | 81 | async fn print_periodic_metrics( 82 | response_count: Arc, 83 | latency: Arc, 84 | concurrent_count: Arc, 85 | max_concurrent: usize, 86 | ) { 87 | let mut interval = tokio::time::interval(Duration::from_secs(1)); 88 | loop { 89 | let start = Instant::now(); 90 | interval.tick().await; 91 | let total = response_count.swap(0, std::sync::atomic::Ordering::Relaxed); 92 | let hz = (total as f64) / start.elapsed().as_secs_f64().max(0.1); 93 | 94 | let latency = latency.drain(); 95 | let p90 = latency 96 | .percentile(0.9) 97 | .unwrap_or_default() 98 | .map(|b| *b.range().end()) 99 | .unwrap_or_default() as f64 100 | / 1000.0; 
101 | let p999 = latency 102 | .percentile(0.999) 103 | .unwrap_or_default() 104 | .map(|b| *b.range().end()) 105 | .unwrap_or_default() as f64 106 | / 1000.0; 107 | let p9999 = latency 108 | .percentile(0.9999) 109 | .unwrap_or_default() 110 | .map(|b| *b.range().end()) 111 | .unwrap_or_default() as f64 112 | / 1000.0; 113 | let concurrent = max_concurrent - concurrent_count.available_permits(); 114 | eprintln!("Messages: {total:10} rate: {hz:9.1}hz p90: {p90:6.1}µs p999: {p999:6.1}µs p9999: {p9999:6.1}µs concurrency: {concurrent}"); 115 | } 116 | } 117 | 118 | async fn generate_traffic( 119 | concurrent_count: Arc, 120 | client: RpcClient, 121 | metrics_count: Arc, 122 | metrics_latency: Arc, 123 | ) { 124 | log::debug!("running traffic generator"); 125 | let mut i = 1; 126 | loop { 127 | let permit = concurrent_count 128 | .clone() 129 | .acquire_owned() 130 | .await 131 | .expect("semaphore works"); 132 | if i % 2 == 0 { 133 | match client 134 | .send_unary(Request { 135 | request_id: i, 136 | code: ProtosocketControlCode::Normal as u32, 137 | body: Some(EchoRequest { 138 | message: i.to_string(), 139 | nanotime: SystemTime::now() 140 | .duration_since(UNIX_EPOCH) 141 | .expect("time works") 142 | .as_nanos() as u64, 143 | }), 144 | response_behavior: ResponseBehavior::Unary, 145 | }) 146 | .await 147 | { 148 | Ok(completion) => { 149 | i += 1; 150 | let metrics_count = metrics_count.clone(); 151 | let metrics_latency = metrics_latency.clone(); 152 | tokio::spawn(async move { 153 | let response = completion.await.expect("response must be successful"); 154 | handle_response(response, metrics_count, metrics_latency); 155 | drop(permit); 156 | }); 157 | } 158 | Err(e) => { 159 | log::error!("send should work: {e:?}"); 160 | return; 161 | } 162 | } 163 | } else { 164 | match client 165 | .send_streaming(Request { 166 | request_id: i, 167 | code: ProtosocketControlCode::Normal as u32, 168 | body: Some(EchoRequest { 169 | message: i.to_string(), 170 | nanotime: 
SystemTime::now() 171 | .duration_since(UNIX_EPOCH) 172 | .expect("time works") 173 | .as_nanos() as u64, 174 | }), 175 | response_behavior: ResponseBehavior::Stream, 176 | }) 177 | .await 178 | { 179 | Ok(mut completion) => { 180 | i += 1; 181 | let metrics_count = metrics_count.clone(); 182 | let metrics_latency = metrics_latency.clone(); 183 | tokio::spawn(async move { 184 | while let Some(Ok(response)) = completion.next().await { 185 | handle_stream_response( 186 | response, 187 | metrics_count.clone(), 188 | metrics_latency.clone(), 189 | ); 190 | } 191 | drop(permit); 192 | }); 193 | } 194 | Err(e) => { 195 | log::error!("send should work: {e:?}"); 196 | return; 197 | } 198 | } 199 | } 200 | } 201 | } 202 | 203 | fn handle_response( 204 | response: Response, 205 | metrics_count: Arc, 206 | metrics_latency: Arc, 207 | ) { 208 | metrics_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); 209 | 210 | let request_id = response.request_id; 211 | assert_ne!(response.request_id, 0, "received bad message"); 212 | match response.kind { 213 | Some(EchoResponseKind::Echo(echo)) => { 214 | assert_eq!(request_id, echo.message.parse().unwrap_or_default()); 215 | 216 | let latency = SystemTime::now() 217 | .duration_since(UNIX_EPOCH) 218 | .expect("time works") 219 | .as_nanos() as u64 220 | - echo.nanotime; 221 | let _ = metrics_latency.increment(latency); 222 | } 223 | Some(EchoResponseKind::Stream(_char_response)) => { 224 | log::error!("got a stream response for a unary request"); 225 | } 226 | None => { 227 | log::warn!("no response body"); 228 | } 229 | } 230 | } 231 | 232 | fn handle_stream_response( 233 | response: Response, 234 | metrics_count: Arc, 235 | metrics_latency: Arc, 236 | ) { 237 | log::debug!("received stream response {response:?}"); 238 | metrics_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); 239 | 240 | let request_id = response.request_id; 241 | assert_ne!(response.request_id, 0, "received bad message"); 242 | match response.kind { 
243 | Some(EchoResponseKind::Echo(_echo)) => { 244 | log::error!("got a unary response for a stream request"); 245 | } 246 | Some(EchoResponseKind::Stream(char_response)) => { 247 | assert_eq!( 248 | request_id.to_string() 249 | [(char_response.sequence as usize)..=(char_response.sequence as usize)], 250 | char_response.message 251 | ); 252 | 253 | let latency = SystemTime::now() 254 | .duration_since(UNIX_EPOCH) 255 | .expect("time works") 256 | .as_nanos() as u64 257 | - char_response.nanotime; 258 | let _ = metrics_latency.increment(latency); 259 | } 260 | None => { 261 | log::warn!("no response body"); 262 | } 263 | } 264 | } 265 | -------------------------------------------------------------------------------- /example-messagepack/src/messages.rs: -------------------------------------------------------------------------------- 1 | //! If you're only using rust, of course you can hand-write prost structs, but if you 2 | //! want to use a protosocket server with clients in other languages you'll want to 3 | //! generate from protos. 
4 | 5 | use protosocket_rpc::ProtosocketControlCode; 6 | 7 | #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] 8 | pub struct Request { 9 | pub request_id: u64, 10 | pub code: u32, 11 | pub body: Option, 12 | pub response_behavior: ResponseBehavior, 13 | } 14 | 15 | #[derive( 16 | Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize, 17 | )] 18 | #[repr(i32)] 19 | pub enum ResponseBehavior { 20 | Unary = 0, 21 | Stream = 1, 22 | } 23 | 24 | #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] 25 | pub struct EchoRequest { 26 | pub message: String, 27 | pub nanotime: u64, 28 | } 29 | 30 | #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] 31 | pub struct Response { 32 | pub request_id: u64, 33 | pub code: u32, 34 | pub kind: Option, 35 | } 36 | 37 | #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] 38 | pub enum EchoResponseKind { 39 | Echo(EchoResponse), 40 | Stream(EchoStream), 41 | } 42 | 43 | #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] 44 | pub struct EchoResponse { 45 | pub message: String, 46 | pub nanotime: u64, 47 | } 48 | 49 | #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] 50 | pub struct EchoStream { 51 | pub message: String, 52 | pub nanotime: u64, 53 | pub sequence: u64, 54 | } 55 | 56 | impl protosocket_rpc::Message for Request { 57 | fn message_id(&self) -> u64 { 58 | self.request_id 59 | } 60 | 61 | fn control_code(&self) -> ProtosocketControlCode { 62 | ProtosocketControlCode::from_u8(self.code as u8) 63 | } 64 | 65 | fn cancelled(request_id: u64) -> Self { 66 | Request { 67 | request_id, 68 | code: ProtosocketControlCode::Cancel as u32, 69 | body: None, 70 | response_behavior: ResponseBehavior::Unary, 71 | } 72 | } 73 | 74 | fn set_message_id(&mut self, message_id: u64) { 75 | self.request_id = message_id; 76 | } 77 | 78 | fn ended(request_id: 
u64) -> Self { 79 | Self { 80 | request_id, 81 | code: ProtosocketControlCode::End as u32, 82 | body: None, 83 | response_behavior: ResponseBehavior::Unary, 84 | } 85 | } 86 | } 87 | 88 | impl protosocket_rpc::Message for Response { 89 | fn message_id(&self) -> u64 { 90 | self.request_id 91 | } 92 | 93 | fn control_code(&self) -> ProtosocketControlCode { 94 | ProtosocketControlCode::from_u8(self.code as u8) 95 | } 96 | 97 | fn cancelled(request_id: u64) -> Self { 98 | Response { 99 | request_id, 100 | code: ProtosocketControlCode::Cancel as u32, 101 | kind: None, 102 | } 103 | } 104 | 105 | fn set_message_id(&mut self, message_id: u64) { 106 | self.request_id = message_id 107 | } 108 | 109 | fn ended(request_id: u64) -> Self { 110 | Self { 111 | request_id, 112 | code: ProtosocketControlCode::End as u32, 113 | kind: None, 114 | } 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /example-messagepack/src/server.rs: -------------------------------------------------------------------------------- 1 | use std::sync::atomic::AtomicUsize; 2 | 3 | use futures::{future::BoxFuture, stream::BoxStream, FutureExt, Stream, StreamExt}; 4 | use messages::{EchoRequest, EchoResponse, EchoStream, Request, Response, ResponseBehavior}; 5 | use protosocket_rpc::{ 6 | server::{ConnectionService, RpcKind, SocketService}, 7 | ProtosocketControlCode, 8 | }; 9 | 10 | mod messages; 11 | 12 | fn main() -> Result<(), Box> { 13 | static I: AtomicUsize = AtomicUsize::new(0); 14 | let runtime = tokio::runtime::Builder::new_multi_thread() 15 | .thread_name_fn(|| { 16 | format!( 17 | "app-{}", 18 | I.fetch_add(1, std::sync::atomic::Ordering::Relaxed) 19 | ) 20 | }) 21 | .worker_threads(2) 22 | .event_interval(7) 23 | .enable_all() 24 | .build()?; 25 | 26 | runtime.block_on(run_main()) 27 | } 28 | 29 | #[allow(clippy::expect_used)] 30 | async fn run_main() -> Result<(), Box> { 31 | env_logger::init(); 32 | let mut server = 
protosocket_rpc::server::SocketRpcServer::new( 33 | std::env::var("HOST") 34 | .unwrap_or_else(|_| "0.0.0.0:9000".to_string()) 35 | .parse()?, 36 | DemoRpcSocketService, 37 | ) 38 | .await?; 39 | server.set_max_queued_outbound_messages(512); 40 | 41 | tokio::spawn(server).await??; 42 | Ok(()) 43 | } 44 | 45 | /// This is the service that will be used to handle new connections. 46 | /// It doesn't do much; yours might be simple like this too, or it might wire your per-connection 47 | /// ConnectionServices to application-wide state tracking. 48 | struct DemoRpcSocketService; 49 | impl SocketService for DemoRpcSocketService { 50 | type RequestDeserializer = protosocket_messagepack::ProtosocketMessagePackDeserializer; 51 | type ResponseSerializer = protosocket_messagepack::ProtosocketMessagePackSerializer; 52 | type ConnectionService = DemoRpcConnectionServer; 53 | 54 | fn deserializer(&self) -> Self::RequestDeserializer { 55 | Self::RequestDeserializer::default() 56 | } 57 | 58 | fn serializer(&self) -> Self::ResponseSerializer { 59 | Self::ResponseSerializer::default() 60 | } 61 | 62 | fn new_connection_service(&self, address: std::net::SocketAddr) -> Self::ConnectionService { 63 | log::info!("new connection server {address}"); 64 | DemoRpcConnectionServer { address } 65 | } 66 | } 67 | 68 | /// This is the entry point for each Connection. State per-connection is tracked, and you 69 | /// get mutable access to the service on each new rpc for state tracking. 70 | struct DemoRpcConnectionServer { 71 | address: std::net::SocketAddr, 72 | } 73 | impl ConnectionService for DemoRpcConnectionServer { 74 | type Request = Request; 75 | type Response = Response; 76 | // Ideally you'd use real Future and Stream types here for performance and debuggability. 77 | // For a demo though, it's fine to use BoxFuture and BoxStream. 
78 | type UnaryFutureType = BoxFuture<'static, Response>; 79 | type StreamType = BoxStream<'static, Response>; 80 | 81 | fn new_rpc( 82 | &mut self, 83 | initiating_message: Self::Request, 84 | ) -> RpcKind { 85 | log::debug!("{} new rpc: {initiating_message:?}", self.address); 86 | let request_id = initiating_message.request_id; 87 | let behavior = initiating_message.response_behavior; 88 | match initiating_message.body { 89 | Some(echo) => match behavior { 90 | ResponseBehavior::Unary => RpcKind::Unary(echo_request(request_id, echo).boxed()), 91 | ResponseBehavior::Stream => { 92 | RpcKind::Streaming(echo_stream(request_id, echo).boxed()) 93 | } 94 | }, 95 | None => { 96 | // No completion messages will be sent for this message 97 | log::warn!( 98 | "{request_id} no request in rpc body. This may cause a client memory leak." 99 | ); 100 | RpcKind::Unknown 101 | } 102 | } 103 | } 104 | } 105 | 106 | async fn echo_request(request_id: u64, echo: EchoRequest) -> Response { 107 | Response { 108 | request_id, 109 | code: ProtosocketControlCode::Normal as u32, 110 | kind: Some(messages::EchoResponseKind::Echo(EchoResponse { 111 | message: echo.message, 112 | nanotime: echo.nanotime, 113 | })), 114 | } 115 | } 116 | 117 | fn echo_stream(request_id: u64, echo: EchoRequest) -> impl Stream { 118 | let nanotime = echo.nanotime; 119 | futures::stream::iter(echo.message.into_bytes().into_iter().enumerate().map( 120 | move |(sequence, c)| Response { 121 | request_id, 122 | code: ProtosocketControlCode::Normal as u32, 123 | kind: Some(messages::EchoResponseKind::Stream(EchoStream { 124 | message: (c as char).to_string(), 125 | nanotime, 126 | sequence: sequence as u64, 127 | })), 128 | }, 129 | )) 130 | } 131 | -------------------------------------------------------------------------------- /example-proto/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "example-proto" 3 | version = "0.1.0" 4 | edition = "2021" 
5 | publish = false 6 | 7 | [[bin]] 8 | name = "example-proto-server" 9 | path = "src/server.rs" 10 | 11 | [[bin]] 12 | name = "example-proto-client" 13 | path = "src/client.rs" 14 | 15 | [dependencies] 16 | protosocket = { workspace = true } 17 | protosocket-prost = { workspace = true } 18 | protosocket-rpc = { workspace = true } 19 | protosocket-server = { workspace = true } 20 | 21 | bytes = { workspace = true } 22 | env_logger = { workspace = true } 23 | futures = { workspace = true } 24 | log = { workspace = true } 25 | prost = { workspace = true } 26 | tokio = { workspace = true, features = ["full"] } 27 | 28 | histogram = { version = "0.11" } 29 | -------------------------------------------------------------------------------- /example-proto/src/client.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | sync::{atomic::AtomicUsize, Arc}, 3 | time::{Duration, Instant, SystemTime, UNIX_EPOCH}, 4 | }; 5 | 6 | use futures::{stream::FuturesUnordered, task::SpawnExt, StreamExt}; 7 | use messages::{EchoRequest, EchoResponseKind, Request, Response, ResponseBehavior}; 8 | use protosocket_rpc::{client::RpcClient, ProtosocketControlCode}; 9 | use tokio::sync::Semaphore; 10 | 11 | mod messages; 12 | 13 | fn main() -> Result<(), Box> { 14 | static I: AtomicUsize = AtomicUsize::new(0); 15 | let runtime = tokio::runtime::Builder::new_multi_thread() 16 | .thread_name_fn(|| { 17 | format!( 18 | "app-{}", 19 | I.fetch_add(1, std::sync::atomic::Ordering::Relaxed) 20 | ) 21 | }) 22 | .worker_threads(2) 23 | .event_interval(7) 24 | .enable_all() 25 | .build()?; 26 | 27 | runtime.block_on(run_main()) 28 | } 29 | 30 | async fn run_main() -> Result<(), Box> { 31 | env_logger::init(); 32 | 33 | let response_count = Arc::new(AtomicUsize::new(0)); 34 | let latency = Arc::new(histogram::AtomicHistogram::new(7, 52).expect("histogram works")); 35 | 36 | let max_concurrent = 512; 37 | let concurrent_count = 
Arc::new(Semaphore::new(max_concurrent)); 38 | for _i in 0..2 { 39 | let (client, connection) = protosocket_rpc::client::connect::< 40 | protosocket_prost::ProstSerializer, 41 | protosocket_prost::ProstSerializer, 42 | >( 43 | std::env::var("ENDPOINT") 44 | .unwrap_or_else(|_| "127.0.0.1:9000".to_string()) 45 | .parse() 46 | .expect("must use a valid socket address"), 47 | &Default::default(), 48 | ) 49 | .await?; 50 | let _connection_handle = tokio::spawn(connection); 51 | let _client_handle = tokio::spawn(generate_traffic( 52 | concurrent_count.clone(), 53 | client, 54 | response_count.clone(), 55 | latency.clone(), 56 | )); 57 | } 58 | 59 | let metrics = tokio::spawn(print_periodic_metrics( 60 | response_count, 61 | latency, 62 | concurrent_count, 63 | max_concurrent, 64 | )); 65 | 66 | tokio::select!( 67 | // _ = connection_driver => { 68 | // log::warn!("connection driver quit"); 69 | // } 70 | // _ = client_runtime => { 71 | // log::warn!("client runtime quit"); 72 | // } 73 | _ = metrics => { 74 | log::warn!("metrics runtime quit"); 75 | } 76 | ); 77 | 78 | Ok(()) 79 | } 80 | 81 | async fn print_periodic_metrics( 82 | response_count: Arc, 83 | latency: Arc, 84 | concurrent_count: Arc, 85 | max_concurrent: usize, 86 | ) { 87 | let mut interval = tokio::time::interval(Duration::from_secs(1)); 88 | loop { 89 | let start = Instant::now(); 90 | interval.tick().await; 91 | let total = response_count.swap(0, std::sync::atomic::Ordering::Relaxed); 92 | let hz = (total as f64) / start.elapsed().as_secs_f64().max(0.1); 93 | 94 | let latency = latency.drain(); 95 | let p90 = latency 96 | .percentile(0.9) 97 | .unwrap_or_default() 98 | .map(|b| *b.range().end()) 99 | .unwrap_or_default() as f64 100 | / 1000.0; 101 | let p999 = latency 102 | .percentile(0.999) 103 | .unwrap_or_default() 104 | .map(|b| *b.range().end()) 105 | .unwrap_or_default() as f64 106 | / 1000.0; 107 | let p9999 = latency 108 | .percentile(0.9999) 109 | .unwrap_or_default() 110 | .map(|b| 
*b.range().end()) 111 | .unwrap_or_default() as f64 112 | / 1000.0; 113 | let concurrent = max_concurrent - concurrent_count.available_permits(); 114 | eprintln!("Messages: {total:10} rate: {hz:9.1}hz p90: {p90:6.1}µs p999: {p999:6.1}µs p9999: {p9999:6.1}µs concurrency: {concurrent}"); 115 | } 116 | } 117 | 118 | async fn generate_traffic( 119 | concurrent_count: Arc, 120 | client: RpcClient, 121 | metrics_count: Arc, 122 | metrics_latency: Arc, 123 | ) { 124 | log::debug!("running traffic generator"); 125 | let mut i = 1; 126 | let mut wip = FuturesUnordered::new(); 127 | loop { 128 | let permit = tokio::select! { 129 | permit = concurrent_count 130 | .clone() 131 | .acquire_owned() => { 132 | permit.expect("semaphore works") 133 | } 134 | _ = wip.select_next_some() => { 135 | // completed one 136 | continue 137 | } 138 | }; 139 | 140 | if i % 2 == 0 { 141 | match client 142 | .send_unary(Request { 143 | request_id: i, 144 | code: ProtosocketControlCode::Normal as u32, 145 | body: Some(EchoRequest { 146 | message: i.to_string(), 147 | nanotime: SystemTime::now() 148 | .duration_since(UNIX_EPOCH) 149 | .expect("time works") 150 | .as_nanos() as u64, 151 | }), 152 | response_behavior: ResponseBehavior::Unary as i32, 153 | }) 154 | .await 155 | { 156 | Ok(completion) => { 157 | i += 1; 158 | let metrics_count = metrics_count.clone(); 159 | let metrics_latency = metrics_latency.clone(); 160 | wip.spawn(async move { 161 | let response = completion.await.expect("response must be successful"); 162 | handle_response(response, metrics_count, metrics_latency); 163 | drop(permit); 164 | }) 165 | .expect("can spawn"); 166 | } 167 | Err(e) => { 168 | log::error!("send should work: {e:?}"); 169 | return; 170 | } 171 | } 172 | } else { 173 | match client 174 | .send_streaming(Request { 175 | request_id: i, 176 | code: ProtosocketControlCode::Normal as u32, 177 | body: Some(EchoRequest { 178 | message: i.to_string(), 179 | nanotime: SystemTime::now() 180 | 
.duration_since(UNIX_EPOCH) 181 | .expect("time works") 182 | .as_nanos() as u64, 183 | }), 184 | response_behavior: ResponseBehavior::Stream as i32, 185 | }) 186 | .await 187 | { 188 | Ok(mut completion) => { 189 | i += 1; 190 | let metrics_count = metrics_count.clone(); 191 | let metrics_latency = metrics_latency.clone(); 192 | wip.spawn(async move { 193 | while let Some(Ok(response)) = completion.next().await { 194 | handle_stream_response( 195 | response, 196 | metrics_count.clone(), 197 | metrics_latency.clone(), 198 | ); 199 | } 200 | drop(permit); 201 | }) 202 | .expect("can spawn"); 203 | } 204 | Err(e) => { 205 | log::error!("send should work: {e:?}"); 206 | return; 207 | } 208 | } 209 | } 210 | } 211 | } 212 | 213 | fn handle_response( 214 | response: Response, 215 | metrics_count: Arc, 216 | metrics_latency: Arc, 217 | ) { 218 | metrics_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); 219 | 220 | let request_id = response.request_id; 221 | assert_ne!(response.request_id, 0, "received bad message"); 222 | match response.kind { 223 | Some(EchoResponseKind::Echo(echo)) => { 224 | assert_eq!(request_id, echo.message.parse().unwrap_or_default()); 225 | 226 | let latency = SystemTime::now() 227 | .duration_since(UNIX_EPOCH) 228 | .expect("time works") 229 | .as_nanos() as u64 230 | - echo.nanotime; 231 | let _ = metrics_latency.increment(latency); 232 | } 233 | Some(EchoResponseKind::Stream(_char_response)) => { 234 | log::error!("got a stream response for a unary request"); 235 | } 236 | None => { 237 | log::warn!("no response body"); 238 | } 239 | } 240 | } 241 | 242 | fn handle_stream_response( 243 | response: Response, 244 | metrics_count: Arc, 245 | metrics_latency: Arc, 246 | ) { 247 | log::debug!("received stream response {response:?}"); 248 | metrics_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); 249 | 250 | let request_id = response.request_id; 251 | assert_ne!(response.request_id, 0, "received bad message"); 252 | match 
response.kind { 253 | Some(EchoResponseKind::Echo(_echo)) => { 254 | log::error!("got a unary response for a stream request"); 255 | } 256 | Some(EchoResponseKind::Stream(char_response)) => { 257 | let places = request_id.ilog(10); 258 | let place = places - char_response.sequence as u32; 259 | let column = (request_id / 10u64.pow(place)) % 10; 260 | 261 | assert_eq!(Ok(column), char_response.message.parse()); 262 | 263 | if place == places - 1 { 264 | let latency = SystemTime::now() 265 | .duration_since(UNIX_EPOCH) 266 | .expect("time works") 267 | .as_nanos() as u64 268 | - char_response.nanotime; 269 | let _ = metrics_latency.increment(latency); 270 | } 271 | } 272 | None => { 273 | log::warn!("no response body"); 274 | } 275 | } 276 | } 277 | -------------------------------------------------------------------------------- /example-proto/src/messages.rs: -------------------------------------------------------------------------------- 1 | //! If you're only using rust, of course you can hand-write prost structs, but if you 2 | //! want to use a protosocket server with clients in other languages you'll want to 3 | //! generate from protos. 
use protosocket_rpc::ProtosocketControlCode;

/// Wire-level RPC request envelope.
///
/// NOTE(review): the `Option<...>` generic parameters in this file were
/// reconstructed from how the server consumes these fields - confirm against
/// the generated/upstream sources.
#[derive(Clone, PartialEq, Eq, prost::Message)]
pub struct Request {
    /// Correlation id; every response for this rpc echoes it back.
    #[prost(uint64, tag = "1")]
    pub request_id: u64,
    /// Raw protosocket control code - see `ProtosocketControlCode`.
    #[prost(uint32, tag = "2")]
    pub code: u32,
    /// Request payload. `None` for control-only messages (cancel/end).
    #[prost(message, tag = "3")]
    pub body: Option<EchoRequest>,
    /// How the server should reply: one message or a per-character stream.
    #[prost(enumeration = "ResponseBehavior", tag = "4")]
    pub response_behavior: i32,
}

/// Whether the server answers with a single response or a stream of responses.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, prost::Enumeration)]
#[repr(i32)]
pub enum ResponseBehavior {
    Unary = 0,
    Stream = 1,
}

/// Echo payload: a message plus the sender's clock reading, used by the
/// client to compute round-trip latency.
#[derive(Clone, PartialEq, Eq, prost::Message)]
pub struct EchoRequest {
    #[prost(string, tag = "1")]
    pub message: String,
    /// Sender's nanoseconds-since-unix-epoch at send time.
    #[prost(uint64, tag = "2")]
    pub nanotime: u64,
}

/// Wire-level RPC response envelope.
#[derive(Clone, PartialEq, Eq, prost::Message)]
pub struct Response {
    /// Correlation id copied from the originating `Request`.
    #[prost(uint64, tag = "1")]
    pub request_id: u64,
    /// Raw protosocket control code - see `ProtosocketControlCode`.
    #[prost(uint32, tag = "2")]
    pub code: u32,
    /// Unary or streamed payload; `None` for control-only messages.
    #[prost(oneof = "EchoResponseKind", tags = "3, 4")]
    pub kind: Option<EchoResponseKind>,
}

/// The two shapes a response payload can take (tags 3 and 4 of `Response`).
#[derive(Clone, PartialEq, Eq, prost::Oneof)]
pub enum EchoResponseKind {
    /// Single response to a unary echo rpc.
    #[prost(message, tag = "3")]
    Echo(EchoResponse),
    /// One element of a streaming echo rpc.
    #[prost(message, tag = "4")]
    Stream(EchoStream),
}

/// Unary echo response body: the original message and timestamp, returned verbatim.
#[derive(Clone, PartialEq, Eq, prost::Message)]
pub struct EchoResponse {
    #[prost(string, tag = "1")]
    pub message: String,
    #[prost(uint64, tag = "2")]
    pub nanotime: u64,
}

/// Streaming echo response body: one character of the original message per element.
#[derive(Clone, PartialEq, Eq, prost::Message)]
pub struct EchoStream {
    #[prost(string, tag = "1")]
    pub message: String,
    #[prost(uint64, tag = "2")]
    pub nanotime: u64,
    /// 0-based position of this character within the original message.
    #[prost(uint64, tag = "3")]
    pub sequence: u64,
}

/// Glue that lets protosocket-rpc correlate and control-route `Request`s.
impl protosocket_rpc::Message for Request {
    /// The id used to match requests with their responses.
    fn message_id(&self) -> u64 {
        self.request_id
    }

    fn control_code(&self) -> ProtosocketControlCode {
        // Only the low byte of `code` carries the control code.
        ProtosocketControlCode::from_u8(self.code as u8)
    }

    /// A control-only message telling the peer that `request_id` was cancelled.
    fn cancelled(request_id: u64) -> Self {
        Request {
            request_id,
            code: ProtosocketControlCode::Cancel as u32,
            body: None,
            response_behavior: ResponseBehavior::Unary as i32,
        }
    }

    fn set_message_id(&mut self, message_id: u64) {
        self.request_id = message_id;
    }

    /// A control-only message telling the peer that `request_id` is complete.
    fn ended(request_id: u64) -> Self {
        Self {
            request_id,
            code: ProtosocketControlCode::End as u32,
            body: None,
            response_behavior: ResponseBehavior::Unary as i32,
        }
    }
}

/// Glue that lets protosocket-rpc correlate and control-route `Response`s.
impl protosocket_rpc::Message for Response {
    /// The id used to match responses back to their requests.
    fn message_id(&self) -> u64 {
        self.request_id
    }

    fn control_code(&self) -> ProtosocketControlCode {
        // Only the low byte of `code` carries the control code.
        ProtosocketControlCode::from_u8(self.code as u8)
    }

    /// A control-only message telling the peer that `request_id` was cancelled.
    fn cancelled(request_id: u64) -> Self {
        Response {
            request_id,
            code: ProtosocketControlCode::Cancel as u32,
            kind: None,
        }
    }

    fn set_message_id(&mut self, message_id: u64) {
        self.request_id = message_id
    }

    /// A control-only message telling the peer that `request_id` is complete.
    fn ended(request_id: u64) -> Self {
        Self {
            request_id,
            code: ProtosocketControlCode::End as u32,
            kind: None,
        }
    }
}
runtime = tokio::runtime::Builder::new_multi_thread() 16 | .thread_name_fn(|| { 17 | format!( 18 | "app-{}", 19 | I.fetch_add(1, std::sync::atomic::Ordering::Relaxed) 20 | ) 21 | }) 22 | .worker_threads(2) 23 | .event_interval(7) 24 | .enable_all() 25 | .build()?; 26 | 27 | runtime.block_on(run_main()) 28 | } 29 | 30 | #[allow(clippy::expect_used)] 31 | async fn run_main() -> Result<(), Box> { 32 | env_logger::init(); 33 | let mut server = protosocket_rpc::server::SocketRpcServer::new( 34 | std::env::var("HOST") 35 | .unwrap_or_else(|_| "0.0.0.0:9000".to_string()) 36 | .parse()?, 37 | DemoRpcSocketService, 38 | ) 39 | .await?; 40 | server.set_max_queued_outbound_messages(512); 41 | 42 | tokio::spawn(server).await??; 43 | Ok(()) 44 | } 45 | 46 | /// This is the service that will be used to handle new connections. 47 | /// It doesn't do much; yours might be simple like this too, or it might wire your per-connection 48 | /// ConnectionServices to application-wide state tracking. 49 | struct DemoRpcSocketService; 50 | impl SocketService for DemoRpcSocketService { 51 | type RequestDeserializer = ProstSerializer; 52 | type ResponseSerializer = ProstSerializer; 53 | type ConnectionService = DemoRpcConnectionServer; 54 | 55 | fn deserializer(&self) -> Self::RequestDeserializer { 56 | Self::RequestDeserializer::default() 57 | } 58 | 59 | fn serializer(&self) -> Self::ResponseSerializer { 60 | Self::ResponseSerializer::default() 61 | } 62 | 63 | fn new_connection_service(&self, address: std::net::SocketAddr) -> Self::ConnectionService { 64 | log::info!("new connection server {address}"); 65 | DemoRpcConnectionServer { address } 66 | } 67 | } 68 | 69 | /// This is the entry point for each Connection. State per-connection is tracked, and you 70 | /// get mutable access to the service on each new rpc for state tracking. 
71 | struct DemoRpcConnectionServer { 72 | address: std::net::SocketAddr, 73 | } 74 | impl ConnectionService for DemoRpcConnectionServer { 75 | type Request = Request; 76 | type Response = Response; 77 | // Ideally you'd use real Future and Stream types here for performance and debuggability. 78 | // For a demo though, it's fine to use BoxFuture and BoxStream. 79 | type UnaryFutureType = BoxFuture<'static, Response>; 80 | type StreamType = BoxStream<'static, Response>; 81 | 82 | fn new_rpc( 83 | &mut self, 84 | initiating_message: Self::Request, 85 | ) -> RpcKind { 86 | log::debug!("{} new rpc: {initiating_message:?}", self.address); 87 | let request_id = initiating_message.request_id; 88 | let behavior = initiating_message.response_behavior(); 89 | match initiating_message.body { 90 | Some(echo) => match behavior { 91 | ResponseBehavior::Unary => RpcKind::Unary(echo_request(request_id, echo).boxed()), 92 | ResponseBehavior::Stream => { 93 | RpcKind::Streaming(echo_stream(request_id, echo).boxed()) 94 | } 95 | }, 96 | None => { 97 | // No completion messages will be sent for this message 98 | log::warn!( 99 | "{request_id} no request in rpc body. This may cause a client memory leak." 
100 | ); 101 | RpcKind::Unknown 102 | } 103 | } 104 | } 105 | } 106 | 107 | async fn echo_request(request_id: u64, echo: EchoRequest) -> Response { 108 | Response { 109 | request_id, 110 | code: ProtosocketControlCode::Normal as u32, 111 | kind: Some(messages::EchoResponseKind::Echo(EchoResponse { 112 | message: echo.message, 113 | nanotime: echo.nanotime, 114 | })), 115 | } 116 | } 117 | 118 | fn echo_stream(request_id: u64, echo: EchoRequest) -> impl Stream { 119 | let nanotime = echo.nanotime; 120 | futures::stream::iter(echo.message.into_bytes().into_iter().enumerate().map( 121 | move |(sequence, c)| Response { 122 | request_id, 123 | code: ProtosocketControlCode::Normal as u32, 124 | kind: Some(messages::EchoResponseKind::Stream(EchoStream { 125 | message: (c as char).to_string(), 126 | nanotime, 127 | sequence: sequence as u64, 128 | })), 129 | }, 130 | )) 131 | } 132 | -------------------------------------------------------------------------------- /example-telnet/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "example" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | protosocket = { workspace = true } 9 | protosocket-server = { workspace = true } 10 | 11 | bytes = { workspace = true } 12 | env_logger = { workspace = true } 13 | futures = { workspace = true } 14 | log = { workspace = true } 15 | tokio = { workspace = true, features = ["full"] } 16 | -------------------------------------------------------------------------------- /example-telnet/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | io::Read, 3 | sync::{atomic::AtomicUsize, Arc}, 4 | time::Duration, 5 | }; 6 | 7 | use protosocket::{ 8 | ConnectionBindings, DeserializeError, Deserializer, MessageReactor, ReactorStatus, Serializer, 9 | }; 10 | use protosocket_server::{ProtosocketServer, ServerConnector}; 11 | 12 | 
#[allow(clippy::expect_used)] 13 | #[tokio::main] 14 | async fn main() -> Result<(), Box> { 15 | env_logger::init(); 16 | 17 | let server_context = ServerContext::default(); 18 | let server = ProtosocketServer::new( 19 | "127.0.0.1:9000".parse()?, 20 | tokio::runtime::Handle::current(), 21 | server_context, 22 | ) 23 | .await?; 24 | 25 | tokio::spawn(server).await??; 26 | Ok(()) 27 | } 28 | 29 | #[derive(Default, Clone)] 30 | struct ServerContext { 31 | _connections: Arc, 32 | } 33 | 34 | impl ServerConnector for ServerContext { 35 | type Bindings = StringContext; 36 | 37 | fn serializer(&self) -> ::Serializer { 38 | StringSerializer 39 | } 40 | 41 | fn deserializer(&self) -> ::Deserializer { 42 | StringSerializer 43 | } 44 | 45 | fn new_reactor( 46 | &self, 47 | optional_outbound: tokio::sync::mpsc::Sender< 48 | <::Serializer as Serializer>::Message, 49 | >, 50 | ) -> ::Reactor { 51 | StringReactor { 52 | outbound: optional_outbound, 53 | } 54 | } 55 | } 56 | 57 | struct StringReactor { 58 | outbound: tokio::sync::mpsc::Sender, 59 | } 60 | impl MessageReactor for StringReactor { 61 | type Inbound = String; 62 | 63 | fn on_inbound_messages( 64 | &mut self, 65 | messages: impl IntoIterator, 66 | ) -> ReactorStatus { 67 | for mut message in messages.into_iter() { 68 | let outbound = self.outbound.clone(); 69 | tokio::spawn(async move { 70 | let seconds: u64 = message 71 | .split_ascii_whitespace() 72 | .next() 73 | .unwrap_or("0") 74 | .parse() 75 | .unwrap_or(0); 76 | tokio::time::sleep(Duration::from_secs(seconds)).await; 77 | message.push_str(" RAN"); 78 | if let Err(e) = outbound.send(message).await { 79 | log::error!("send error: {e:?}"); 80 | } 81 | }); 82 | } 83 | ReactorStatus::Continue 84 | } 85 | } 86 | 87 | struct StringContext; 88 | 89 | impl ConnectionBindings for StringContext { 90 | type Deserializer = StringSerializer; 91 | type Serializer = StringSerializer; 92 | type Reactor = StringReactor; 93 | } 94 | 95 | struct StringSerializer; 96 | 97 | impl 
Serializer for StringSerializer { 98 | type Message = String; 99 | 100 | fn encode(&mut self, mut response: Self::Message, buffer: &mut Vec) { 101 | response.push_str(" ENCODED\n"); 102 | buffer.extend_from_slice(response.as_bytes()); 103 | } 104 | } 105 | impl Deserializer for StringSerializer { 106 | type Message = String; 107 | 108 | fn decode( 109 | &mut self, 110 | buffer: impl bytes::Buf, 111 | ) -> std::result::Result<(usize, Self::Message), DeserializeError> { 112 | let mut read_buffer: [u8; 1] = [0; 1]; 113 | let read = buffer 114 | .reader() 115 | .read(&mut read_buffer) 116 | .map_err(|_e| DeserializeError::InvalidBuffer)?; 117 | match String::from_utf8(read_buffer.to_vec()) { 118 | Ok(s) => { 119 | let mut s = s.trim().to_string(); 120 | if s.is_empty() { 121 | Err(DeserializeError::SkipMessage { distance: read }) 122 | } else { 123 | s.push_str(" DECODED"); 124 | Ok((read, s)) 125 | } 126 | } 127 | Err(e) => { 128 | log::debug!("invalid message {e:?}"); 129 | Err(DeserializeError::InvalidBuffer) 130 | } 131 | } 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /protosocket-connection/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "protosocket" 3 | description = "Message-oriented nonblocking tcp stream" 4 | version.workspace = true 5 | edition.workspace = true 6 | license.workspace = true 7 | authors.workspace = true 8 | readme.workspace = true 9 | repository.workspace = true 10 | keywords.workspace = true 11 | categories.workspace = true 12 | 13 | [dependencies] 14 | bytes = { workspace = true } 15 | futures = { workspace = true } 16 | log = { workspace = true } 17 | serde = { workspace = true, optional = true } 18 | thiserror = { workspace = true } 19 | tokio = { workspace = true } 20 | tokio-util = { workspace = true } 21 | -------------------------------------------------------------------------------- 
/protosocket-connection/README.md: -------------------------------------------------------------------------------- 1 | # protosocket-connection 2 | 3 | `protosocket-connection` provides a flexible, asynchronous TCP connection handler. It's designed to efficiently manage bidirectional, message-oriented TCP streams with customizable serialization and deserialization. 4 | 5 | ## Key Features 6 | - Low abstraction - no http or higher level constructs 7 | - Asynchronous I/O using `mio` via `tokio` 8 | - Customizable message types through `ConnectionBindings` 9 | - Efficient buffer management and flexible error handling 10 | 11 | ## Flow Diagrams 12 | 13 | The `poll()` function on `connection.rs` controls the lifecycle of the entire connection. You're recommended to read individual comments on the code to understand the flow, but below is a sequence diagram to get you started: 14 | 15 | ```mermaid 16 | sequenceDiagram 17 | participant P as Poll 18 | participant IS as Inbound Socket 19 | participant D as Deserializer 20 | participant R as Reactor 21 | participant S as Serializer 22 | participant OS as Outbound Socket 23 | 24 | P->>IS: Read from inbound socket 25 | IS->>D: Raw data 26 | D->>P: Deserialize inbound messages 27 | P->>R: Submit inbound messages 28 | R-->>P: Process messages 29 | P->>S: Prepare and serialize outbound messages 30 | S->>P: Serialized outbound bytes 31 | P->>OS: Write to outbound socket 32 | ``` -------------------------------------------------------------------------------- /protosocket-connection/src/connection.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::VecDeque, 3 | future::Future, 4 | io::IoSlice, 5 | pin::Pin, 6 | task::{Context, Poll}, 7 | }; 8 | 9 | use tokio::sync::mpsc; 10 | 11 | use crate::{ 12 | interrupted, 13 | types::{ConnectionBindings, DeserializeError, MessageReactor, ReactorStatus}, 14 | would_block, Deserializer, Serializer, 15 | }; 16 | 17 | /// A 
/// A bidirectional, message-oriented tcp stream wrapper.
///
/// Connections are Futures that you spawn.
/// To send messages, you push them into the outbound message stream.
/// To receive messages, you implement a `MessageReactor`. Inbound messages are not
/// wrapped in a Stream, in order to avoid an extra layer of async buffering. If you
/// need to buffer messages or forward them to a Stream, you can do so in the reactor.
///
/// NOTE(review): generic parameters throughout this file were reconstructed
/// from context - verify against the upstream protosocket sources.
pub struct Connection<Bindings: ConnectionBindings> {
    stream: tokio::net::TcpStream,
    // Peer address, retained for Display/logging only.
    address: std::net::SocketAddr,
    // Messages submitted by the application, awaiting serialization.
    outbound_messages: mpsc::Receiver<<Bindings::Serializer as Serializer>::Message>,
    // Staging area for poll_recv_many batches; kept reversed so pop() yields FIFO order.
    outbound_message_buffer: Vec<<Bindings::Serializer as Serializer>::Message>,
    // Deserialized messages, drained into the reactor once per poll.
    inbound_messages: Vec<<Bindings::Deserializer as Deserializer>::Message>,
    // Recycled byte buffers for serialization, amortizing allocation.
    serializer_buffers: Vec<Vec<u8>>,
    // Serialized buffers queued for vectored writes to the socket.
    send_buffer: VecDeque<Vec<u8>>,
    // One past the last byte read from the socket into receive_buffer.
    receive_buffer_slice_end: usize,
    // First byte of receive_buffer not yet consumed by the deserializer.
    receive_buffer_start_offset: usize,
    receive_buffer: Vec<u8>,
    // Scratch buffer used to compact receive_buffer when a long message needs room.
    receive_buffer_swap: Vec<u8>,
    max_buffer_length: usize,
    max_queued_send_messages: usize,
    deserializer: Bindings::Deserializer,
    serializer: Bindings::Serializer,
    reactor: Bindings::Reactor,
}

impl<Bindings: ConnectionBindings> std::fmt::Display for Connection<Bindings> {
    /// Render a one-line snapshot of buffer occupancy for log messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let read_start = self.receive_buffer_start_offset;
        let read_end = self.receive_buffer_slice_end;
        let read_capacity = self.receive_buffer.len();
        let write_queue = self.send_buffer.len();
        let write_length: usize = self.send_buffer.iter().map(|b| b.len()).sum();
        let address = self.address;
        write!(f, "Connection: {address} {{read{{start: {read_start}, end: {read_end}, capacity: {read_capacity}}}, write{{queue: {write_queue}, length: {write_length}}}}}")
    }
}

impl<Bindings: ConnectionBindings> Unpin for Connection<Bindings> {}

impl<Bindings: ConnectionBindings> Future for Connection<Bindings> {
    type Output = ();

    /// Take a look at ConnectionBindings for the type definitions used by the Connection
    ///
    /// This method performs the following steps:
    ///
    /// 1. Check for read readiness and read into the receive_buffer (up to max_buffer_length).
    /// 2. Deserialize the read bytes into Messages and store them in the inbound_messages queue.
    /// 3. Process all messages in the inbound queue using the user-provided MessageReactor.
    /// 4. Serialize messages from outbound_messages queue, up to max_queued_send_messages.
    /// 5. Check for write readiness and send serialized messages.
    ///
    /// Returning Ready from any step terminates the connection.
    fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        // Step 1: Check if there's space in the receive buffer and the stream is ready for reading
        if let Some(early_out) = self.as_mut().poll_receive(context) {
            return early_out;
        }

        // Step 2: Deserialize read bytes into messages
        if let Some(early_out) = self.as_mut().poll_deserialize(context) {
            return early_out;
        }

        // Step 3: Process inbound messages with the Connection's MessageReactor
        if let Some(early_out) = self.poll_react() {
            return early_out;
        }

        // Step 4: Prepare outbound messages for sending. We serialize the outbound bytes here as per the
        // Connection's serializer
        if let Poll::Ready(early_out) = self.poll_serialize_outbound_messages(context) {
            return Poll::Ready(early_out);
        }

        // Step 5: Send outbound messages if the stream is ready for writing
        if let Some(early_out) = self.poll_outbound(context) {
            return early_out;
        }

        Poll::Pending
    }
}

impl<Bindings: ConnectionBindings> Drop for Connection<Bindings> {
    fn drop(&mut self) {
        log::debug!("connection dropped")
    }
}

/// Outcome of one read or deserialize pass, used to decide whether to
/// continue, yield, or tear down the connection.
#[derive(Debug)]
enum ReadBufferState {
    /// Done consuming until external liveness is signaled
    WaitingForMore,
    /// Need to eagerly wake up again
    PartiallyConsumed,
    /// Connection is disconnected or is to be disconnected
    Disconnected,
    /// Disconnected with an io error
    Error(std::io::Error),
}

impl<Bindings: ConnectionBindings> Connection<Bindings>
where
    <Bindings::Deserializer as Deserializer>::Message: Send,
{
    /// Create a new protosocket Connection with the given stream and reactor.
    ///
    /// Probably you are interested in the `protosocket-server` or `protosocket-prost` crates.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    pub fn new(
        stream: tokio::net::TcpStream,
        address: std::net::SocketAddr,
        deserializer: Bindings::Deserializer,
        serializer: Bindings::Serializer,
        max_buffer_length: usize,
        max_queued_send_messages: usize,
        outbound_messages: mpsc::Receiver<<Bindings::Serializer as Serializer>::Message>,
        reactor: Bindings::Reactor,
    ) -> Connection<Bindings> {
        // outbound must be queued so it can be called from any context
        Self {
            stream,
            address,
            outbound_messages,
            outbound_message_buffer: Vec::new(),
            inbound_messages: Vec::with_capacity(max_queued_send_messages),
            send_buffer: Default::default(),
            serializer_buffers: Vec::from_iter((0..1).map(|_| Vec::new())),
            receive_buffer: Vec::new(),
            receive_buffer_swap: Vec::new(),
            max_buffer_length,
            receive_buffer_start_offset: 0,
            max_queued_send_messages,
            receive_buffer_slice_end: 0,
            deserializer,
            serializer,
            reactor,
        }
    }

    /// ensure buffer state and read from the inbound stream
    fn read_inbound(&mut self) -> ReadBufferState {
        // Grow the receive buffer in 2MiB steps while below max_buffer_length.
        const BUFFER_INCREMENT: usize = 2 << 20;
        if self.receive_buffer.len() < self.max_buffer_length
            && self.receive_buffer.len() - self.receive_buffer_slice_end < BUFFER_INCREMENT
        {
            self.receive_buffer.reserve(BUFFER_INCREMENT);
            // SAFETY: This is a buffer, and u8 is not read until after the read syscall returns. Read initializes the buffer values.
            // I reserved the additional space above, so the additional space is valid.
            // This was done because resizing the buffer shows up on heat maps.
            #[allow(clippy::uninit_vec)]
            unsafe {
                self.receive_buffer
                    .set_len(self.receive_buffer.len() + BUFFER_INCREMENT)
            };
        }

        if 0 < self.receive_buffer.len() - self.receive_buffer_slice_end {
            // We can (maybe) read from the connection.
            self.read_from_stream()
        } else {
            log::debug!("receive is full {self}");
            ReadBufferState::WaitingForMore
        }
    }

    /// process the receive buffer, deserializing bytes into messages
    fn read_inbound_messages_into_read_queue(&mut self) -> ReadBufferState {
        let state = loop {
            if self.receive_buffer_start_offset == self.receive_buffer_slice_end {
                // Everything read so far has been deserialized.
                break ReadBufferState::WaitingForMore;
            }
            if self.inbound_messages.capacity() == self.inbound_messages.len() {
                // can't accept any more inbound messages right now
                log::debug!("full batch of messages read from the socket");
                break ReadBufferState::PartiallyConsumed;
            }

            let buffer = &self.receive_buffer
                [self.receive_buffer_start_offset..self.receive_buffer_slice_end];
            log::trace!("decode {buffer:?}");
            match self.deserializer.decode(buffer) {
                Ok((length, message)) => {
                    self.receive_buffer_start_offset += length;
                    self.inbound_messages.push(message);
                }
                Err(e) => match e {
                    DeserializeError::IncompleteBuffer { next_message_size } => {
                        if self.max_buffer_length < next_message_size {
                            log::error!("tried to receive message that is too long. Resetting connection - max: {}, requested: {}", self.max_buffer_length, next_message_size);
                            return ReadBufferState::Disconnected;
                        }
                        if self.max_buffer_length
                            < self.receive_buffer_slice_end + next_message_size
                        {
                            // Compact: copy the unconsumed tail to the start of the
                            // swap buffer and exchange the two buffers.
                            let length =
                                self.receive_buffer_slice_end - self.receive_buffer_start_offset;
                            log::debug!(
                                "rotating {}b of buffer to make room for next message {}b",
                                length,
                                next_message_size
                            );
                            self.receive_buffer_swap.clear();
                            self.receive_buffer_swap.extend_from_slice(
                                &self.receive_buffer[self.receive_buffer_start_offset
                                    ..self.receive_buffer_slice_end],
                            );
                            std::mem::swap(&mut self.receive_buffer, &mut self.receive_buffer_swap);
                            self.receive_buffer_start_offset = 0;
                            self.receive_buffer_slice_end = length;
                        }
                        log::debug!("waiting for the next message of length {next_message_size}");
                        break ReadBufferState::WaitingForMore;
                    }
                    DeserializeError::InvalidBuffer => {
                        log::error!("message was invalid - broken stream");
                        return ReadBufferState::Disconnected;
                    }
                    DeserializeError::SkipMessage { distance } => {
                        if self.receive_buffer_slice_end - self.receive_buffer_start_offset
                            < distance
                        {
                            log::trace!("cannot skip yet, need to read more. Skipping: {distance}, remaining:{}", self.receive_buffer_slice_end - self.receive_buffer_start_offset);
                            break ReadBufferState::WaitingForMore;
                        }
                        log::debug!("skipping message of length {distance}");
                        self.receive_buffer_start_offset += distance;
                    }
                },
            }
        };
        if self.receive_buffer_start_offset == self.receive_buffer_slice_end
            && self.receive_buffer_start_offset != 0
        {
            // The whole buffer was consumed; rewind both cursors so the next
            // read starts at the front instead of growing the buffer.
            log::debug!("read buffer complete - resetting: {self}");
            self.receive_buffer_start_offset = 0;
            self.receive_buffer_slice_end = 0;
        }
        state
    }

    /// read from the TcpStream
    fn read_from_stream(&mut self) -> ReadBufferState {
        match self
            .stream
            .try_read(&mut self.receive_buffer[self.receive_buffer_slice_end..])
        {
            // A 0-byte read on a readable socket means the peer closed the stream.
            Ok(0) => {
                log::info!(
                    "connection was shut down as recv returned 0. Requested {}",
                    self.receive_buffer.len() - self.receive_buffer_slice_end
                );
                ReadBufferState::Disconnected
            }
            Ok(bytes_read) => {
                self.receive_buffer_slice_end += bytes_read;
                ReadBufferState::PartiallyConsumed
            }
            // Would block "errors" are the OS's way of saying that the
            // connection is not actually ready to perform this I/O operation.
            Err(ref err) if would_block(err) => {
                log::trace!("read everything. No longer readable");
                ReadBufferState::WaitingForMore
            }
            Err(ref err) if interrupted(err) => {
                log::trace!("interrupted, so try again later");
                ReadBufferState::PartiallyConsumed
            }
            // Other errors we'll consider fatal.
            Err(err) => {
                log::warn!("error while reading from tcp stream. buffer length: {}b, offset: {}, offered length {}b, err: {err:?}", self.receive_buffer.len(), self.receive_buffer_slice_end, self.receive_buffer.len() - self.receive_buffer_slice_end);
                ReadBufferState::Error(err)
            }
        }
    }

    /// How many more serialized messages the send queue will accept right now.
    fn room_in_send_buffer(&self) -> usize {
        self.max_queued_send_messages - self.send_buffer.len()
    }

    /// This serializes work-in-progress messages and moves them over into the write queue
    ///
    /// Returns Ready only to terminate the connection (channel closed or
    /// oversized message); otherwise always Pending.
    fn poll_serialize_outbound_messages(&mut self, context: &mut Context<'_>) -> Poll<()> {
        let max_outbound = self.room_in_send_buffer();
        if max_outbound == 0 {
            log::debug!("send is full: {self}");
            // pending on a network status event
            return Poll::Pending;
        }

        let start_len = self.send_buffer.len();
        for _ in 0..max_outbound {
            let message = match self.outbound_message_buffer.pop() {
                Some(next) => next,
                None => {
                    match self.outbound_messages.poll_recv_many(
                        context,
                        &mut self.outbound_message_buffer,
                        self.max_queued_send_messages,
                    ) {
                        Poll::Ready(count) => {
                            // ugh, I know. but poll_recv_many is much cheaper than poll_recv,
                            // and poll_recv requires &mut Vec. Otherwise this would be a VecDeque with no reverse.
                            self.outbound_message_buffer.reverse();
                            match self.outbound_message_buffer.pop() {
                                Some(next) => next,
                                None => {
                                    // Ready with zero received messages means the
                                    // sender side was dropped: terminate.
                                    assert_eq!(0, count);
                                    log::info!("outbound message channel was closed");
                                    return Poll::Ready(());
                                }
                            }
                        }
                        Poll::Pending => {
                            log::trace!("no messages to serialize");
                            break;
                        }
                    }
                }
            };
            let mut buffer = self.serializer_buffers.pop().unwrap_or_default();
            self.serializer.encode(message, &mut buffer);
            if self.max_buffer_length < buffer.len() {
                log::error!(
                    "tried to send too large a message. Max {}, attempted: {}",
                    self.max_buffer_length,
                    buffer.len()
                );
                return Poll::Ready(());
            }
            log::trace!(
                "serialized message and enqueueing outbound buffer: {}b",
                buffer.len()
            );
            // queue up a writev
            self.send_buffer.push_back(buffer);
        }
        let new_len = self.send_buffer.len();
        if start_len != new_len {
            log::debug!(
                "serialized {} messages, waking task to look for more input",
                new_len - start_len
            );
            // if the serializer made progress, there may be more work that the network or outbound channel can do.
            // make sure the task gets another round to try.
            context.waker().wake_by_ref();
        }
        Poll::Pending
    }

    /// Send buffers to the tcp stream, and recycle them if they are fully written
    ///
    /// Returns Ok(true) if the peer closed the write half, Ok(false) otherwise.
    fn writev_buffers(&mut self) -> std::result::Result<bool, std::io::Error> {
        /// I need to figure out how to get this from the os rather than hardcoding. 16 is the lowest I've seen mention of,
        /// and I've seen 1024 more commonly.
        const UIO_MAXIOV: usize = 128;

        let buffers: Vec<IoSlice> = self
            .send_buffer
            .iter()
            .take(UIO_MAXIOV)
            .map(|v| IoSlice::new(v))
            .collect();
        match self.stream.try_write_vectored(&buffers) {
            Ok(0) => {
                log::info!("write stream was closed");
                return Ok(true);
            }
            Ok(written) => {
                self.rotate_send_buffers(written);
            }
            // Would block "errors" are the OS's way of saying that the
            // connection is not actually ready to perform this I/O operation.
            Err(ref err) if would_block(err) => {
                log::trace!("would block - no longer writable");
            }
            Err(ref err) if interrupted(err) => {
                log::trace!("write interrupted - try again later");
            }
            // other errors terminate the stream
            Err(err) => {
                log::warn!(
                    "error while writing to tcp stream: {err:?}, buffers: {}, {}b: {:?}",
                    buffers.len(),
                    buffers.iter().map(|b| b.len()).sum::<usize>(),
                    buffers.into_iter().map(|b| b.len()).collect::<Vec<_>>()
                );
                return Err(err);
            }
        }
        Ok(false)
    }

    /// Discard all written bytes, and recycle the buffers that are fully written
    fn rotate_send_buffers(&mut self, mut written: usize) {
        let total_written = written;
        while 0 < written {
            if let Some(mut front) = self.send_buffer.pop_front() {
                if front.len() <= written {
                    written -= front.len();
                    log::trace!(
                        "recycling buffer of length {}, remaining: {}",
                        front.len(),
                        written
                    );

                    // Reuse the buffer!
                    // SAFETY: This is purely a buffer, and u8 does not require drop.
                    unsafe { front.set_len(0) };
                    self.serializer_buffers.push(front);
                } else {
                    // Walk the buffer forward through a replacement. It will still amortize the allocation,
                    // but this is not optimal. It's relatively easier to manage though, and I'm a busy person.
                    log::debug!("after writing {total_written}b, shifting partially written buffer of {}b by {written}b", front.len());
                    let replacement = front[written..].to_vec();
                    self.send_buffer.push_front(replacement);
                    break;
                }
            } else {
                // Should be unreachable: the kernel reported more bytes written
                // than we handed it.
                log::error!("rotated all buffers but {written} bytes unaccounted for");
                break;
            }
        }
    }

    /// Step 1: drain the readable socket into the receive buffer.
    /// Returns Some to terminate the connection, None to proceed.
    fn poll_receive(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Option<Poll<()>> {
        loop {
            // `break` leaves the loop (buffer full or socket drained);
            // `continue` re-polls readiness after a successful read.
            break if self.receive_buffer_slice_end < self.max_buffer_length {
                match self.stream.poll_read_ready(context) {
                    Poll::Ready(status) => {
                        if let Err(e) = status {
                            log::error!("error while polling read readiness: {e:?}");
                            return Some(Poll::Ready(()));
                        }

                        // Step 1a: read raw bytes from the stream
                        match self.read_inbound() {
                            ReadBufferState::WaitingForMore => {
                                log::debug!(
                                    "consumed all that I can from the read stream for now {self}"
                                );
                                continue;
                            }
                            ReadBufferState::PartiallyConsumed => {
                                log::debug!("more to read");
                                continue;
                            }
                            ReadBufferState::Disconnected => {
                                log::info!("read connection closed");
                                return Some(Poll::Ready(()));
                            }
                            ReadBufferState::Error(e) => {
                                log::warn!("error while reading from tcp stream: {e:?}");
                                return Some(Poll::Ready(()));
                            }
                        }
                    }
                    Poll::Pending => {
                        log::trace!("read side is up to date");
                    }
                }
            } else {
                log::debug!("receive buffer is full");
            };
        }
        None
    }

    /// Step 2: turn buffered bytes into messages.
    /// Returns Some to terminate the connection, None to proceed.
    fn poll_deserialize(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Option<Poll<()>> {
        if self.receive_buffer_start_offset == self.receive_buffer_slice_end {
            // Only when the read buffer is empty should deserialization skip waking the task.
            // I always want the task to be pending on network readability, network writability, or the outbound message channel.
            // Otherwise, I want the task to be woken.
            return None;
        }
        let start_len = self.inbound_messages.len();
        match self.read_inbound_messages_into_read_queue() {
            ReadBufferState::WaitingForMore => {
                log::trace!("read queue is still open");
                let new_len = self.inbound_messages.len();
                if start_len != new_len {
                    // If you don't wake here, then when the read side gets full, you won't be registered for network
                    // activity wakes.
                    // You can skip this wake when you are deserializing a large message, because that is waiting
                    // for more inbound data to arrive. You are only WaitingForMore when you're mid-message,
                    // and if you've just got the one message you can wait for the network.
                    // There's no deserializer progress to report if you didn't deserialize anything.
                    context.waker().wake_by_ref();
                }
            }
            ReadBufferState::PartiallyConsumed => {
                log::debug!("read buffer partially consumed for responsiveness {self}");
                context.waker().wake_by_ref();
            }
            ReadBufferState::Disconnected => {
                log::info!("read queue closed");
                return Some(Poll::Ready(()));
            }
            ReadBufferState::Error(e) => {
                log::warn!("error while reading from buffer: {e:?}");
                return Some(Poll::Ready(()));
            }
        }
        None
    }

    /// Step 3: hand the inbound batch to the reactor.
    /// Returns Some to terminate the connection, None to proceed.
    fn poll_react(&mut self) -> Option<Poll<()>> {
        // Destructure to borrow reactor and the message queue disjointly.
        let Self {
            reactor,
            inbound_messages,
            ..
        } = &mut *self;
        if reactor.on_inbound_messages(inbound_messages.drain(..)) == ReactorStatus::Disconnect {
            log::debug!("reactor requested disconnect");
            return Some(Poll::Ready(()));
        }
        None
    }

    /// Step 5: flush serialized buffers to the writable socket.
    /// Returns Some to terminate the connection, None to proceed.
    fn poll_outbound(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Option<Poll<()>> {
        loop {
            // `break` leaves the loop (nothing to send or would-block);
            // `continue` re-polls writability after a successful write.
            break if !self.send_buffer.is_empty() {
                // the write half of the stream is only considered for readiness when there is something to write.
                match self.stream.poll_write_ready(context) {
                    Poll::Ready(status) => {
                        if let Err(e) = status {
                            log::error!("error while polling write readiness: {e:?}");
                            return Some(Poll::Ready(()));
                        }

                        // Step 5a: write raw bytes to the stream
                        log::debug!("writing {self}");
                        match self.writev_buffers() {
                            Ok(true) => {
                                log::info!("write connection closed");
                                return Some(Poll::Ready(()));
                            }
                            Ok(false) => {
                                log::debug!("wrote {self}");
                            }
                            Err(e) => {
                                log::warn!("error while writing to tcp stream: {e:?}");
                                return Some(Poll::Ready(()));
                            }
                        }

                        log::trace!(
                            "wrote output, checking for more and possibly registering for wake"
                        );
                        continue;
                    }
                    Poll::Pending => {
                        log::debug!("waiting for outbound wake");
                    }
                }
            } else {
                log::trace!("nothing to send");
            };
        }
        None
    }
}
8 | 9 | mod connection; 10 | mod serde; 11 | mod types; 12 | 13 | pub use connection::Connection; 14 | pub use types::ConnectionBindings; 15 | pub use types::DeserializeError; 16 | pub use types::Deserializer; 17 | pub use types::MessageReactor; 18 | pub use types::ReactorStatus; 19 | pub use types::Serializer; 20 | 21 | pub(crate) fn interrupted(err: &std::io::Error) -> bool { 22 | err.kind() == std::io::ErrorKind::Interrupted 23 | } 24 | 25 | pub(crate) fn would_block(err: &std::io::Error) -> bool { 26 | err.kind() == std::io::ErrorKind::WouldBlock 27 | } 28 | -------------------------------------------------------------------------------- /protosocket-connection/src/serde.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /protosocket-connection/src/types.rs: -------------------------------------------------------------------------------- 1 | /// A serializer takes messages and produces outbound bytes. 2 | pub trait Serializer: Unpin + Send { 3 | /// The message type consumed by this serializer. 4 | type Message: Send; 5 | 6 | /// Encode a message into a buffer. 7 | fn encode(&mut self, response: Self::Message, buffer: &mut Vec); 8 | } 9 | 10 | /// A deserializer takes inbound bytes and produces messages. 11 | pub trait Deserializer: Unpin + Send { 12 | /// The message type produced by this deserializer. 13 | type Message: Send; 14 | 15 | /// Decode a message from the buffer, or tell why you can't. 16 | /// 17 | /// You must not consume more bytes than the message you produce. 18 | fn decode( 19 | &mut self, 20 | buffer: impl bytes::Buf, 21 | ) -> std::result::Result<(usize, Self::Message), DeserializeError>; 22 | } 23 | 24 | /// Errors that can occur when deserializing a message. 
25 | #[derive(Debug, thiserror::Error)] 26 | pub enum DeserializeError { 27 | /// Buffer will be retained and you will be called again later with more bytes 28 | #[error("Need more bytes to decode the next message")] 29 | IncompleteBuffer { 30 | /// This is a hint to the connection for how many more bytes should be read. 31 | /// You may be called again before you get another buffer with at least this 32 | /// many bytes. 33 | next_message_size: usize, 34 | }, 35 | /// Buffer will be discarded 36 | #[error("Bad buffer")] 37 | InvalidBuffer, 38 | /// distance will be skipped 39 | #[error("Skip message")] 40 | SkipMessage { 41 | /// If a message is not to be serviced, you can skip it. This is how many 42 | /// bytes will be skipped. 43 | /// You may be called again before this message is skipped and you may need 44 | /// to repeat the skip. 45 | distance: usize, 46 | }, 47 | } 48 | 49 | /// A message reactor is a stateful object that processes inbound messages. 50 | /// You receive &mut self, and you receive your messages by value. 51 | /// 52 | /// A message reactor may be a server which spawns a task per message, or a client which 53 | /// matches response ids to a HashMap of concurrent requests with oneshot completions. 54 | /// 55 | /// Your message reactor and your tcp connection share their fate - when one drops or 56 | /// disconnects, the other does too. 57 | pub trait MessageReactor: Unpin + Send + 'static { 58 | type Inbound; 59 | 60 | /// Called from the connection's driver task when messages are received. 61 | /// 62 | /// You must take all of the messages quickly: Blocking here will block the connection. 63 | /// If you can't accept new messages and you can't queue, you should consider returning 64 | /// Disconnect. 65 | fn on_inbound_messages( 66 | &mut self, 67 | messages: impl IntoIterator, 68 | ) -> ReactorStatus; 69 | } 70 | 71 | /// What the connection should do after processing a batch of inbound messages. 
72 | #[derive(Debug, PartialEq, Eq)] 73 | pub enum ReactorStatus { 74 | /// Continue processing messages. 75 | Continue, 76 | /// Disconnect the tcp connection. 77 | Disconnect, 78 | } 79 | 80 | /// Define the types for a Connection. 81 | /// 82 | /// A protosocket uses only 1 kind of message per port. This is a constraint to keep types 83 | /// straightforward. If you want multiple message types, you should consider using protocol 84 | /// buffers `oneof` fields. You would use a wrapper type to hold the oneof and any additional 85 | /// metadata, like a request ID or trace id. 86 | pub trait ConnectionBindings: 'static { 87 | /// The deserializer for this connection. 88 | type Deserializer: Deserializer; 89 | /// The serializer for this connection. 90 | type Serializer: Serializer; 91 | /// The message reactor for this connection. 92 | type Reactor: MessageReactor::Message>; 93 | } 94 | -------------------------------------------------------------------------------- /protosocket-messagepack/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "protosocket-messagepack" 3 | description = "Message-oriented nonblocking tcp stream - messagepack serde bindings" 4 | version.workspace = true 5 | authors.workspace = true 6 | repository.workspace = true 7 | edition.workspace = true 8 | license.workspace = true 9 | readme.workspace = true 10 | keywords.workspace = true 11 | categories.workspace = true 12 | 13 | [dependencies] 14 | protosocket = { workspace = true } 15 | 16 | bytes = { workspace = true } 17 | log = { workspace = true } 18 | serde = { workspace = true } 19 | rmp = { workspace = true } 20 | rmp-serde = { workspace = true } 21 | -------------------------------------------------------------------------------- /protosocket-messagepack/src/lib.rs: -------------------------------------------------------------------------------- 1 | use std::{io::Read, marker::PhantomData}; 2 | 3 | #[derive(Debug)] 4 | 
pub struct ProtosocketMessagePackSerializer { 5 | _phantom: std::marker::PhantomData, 6 | } 7 | 8 | impl Default for ProtosocketMessagePackSerializer { 9 | fn default() -> Self { 10 | Self { 11 | _phantom: PhantomData, 12 | } 13 | } 14 | } 15 | 16 | impl protosocket::Serializer for ProtosocketMessagePackSerializer 17 | where 18 | T: serde::Serialize + Send + Unpin + std::fmt::Debug, 19 | { 20 | type Message = T; 21 | 22 | fn encode(&mut self, message: Self::Message, buffer: &mut Vec) { 23 | log::debug!("encoding {message:?}"); 24 | // reserve length prefix 25 | buffer.extend_from_slice(&[0; 5]); 26 | rmp_serde::encode::write(buffer, &message).expect("messages must be encodable"); 27 | let len = buffer.len(); 28 | unsafe { 29 | buffer.set_len(0); 30 | } 31 | rmp::encode::write_u32(buffer, len as u32 - 5).expect("message length is encodable"); 32 | unsafe { 33 | buffer.set_len(len); 34 | } 35 | } 36 | } 37 | 38 | #[derive(Debug)] 39 | pub struct ProtosocketMessagePackDeserializer { 40 | _phantom: std::marker::PhantomData, 41 | state: State, 42 | } 43 | 44 | impl Default for ProtosocketMessagePackDeserializer { 45 | fn default() -> Self { 46 | Self { 47 | _phantom: PhantomData, 48 | state: Default::default(), 49 | } 50 | } 51 | } 52 | 53 | #[derive(Debug, Default, Copy, Clone)] 54 | enum State { 55 | #[default] 56 | Waiting, 57 | ReadingLength(u32), 58 | } 59 | 60 | impl protosocket::Deserializer for ProtosocketMessagePackDeserializer 61 | where 62 | T: serde::de::DeserializeOwned + Send + Unpin + std::fmt::Debug, 63 | { 64 | type Message = T; 65 | 66 | fn decode( 67 | &mut self, 68 | buffer: impl bytes::Buf, 69 | ) -> std::result::Result<(usize, Self::Message), protosocket::DeserializeError> { 70 | let start_remaining = buffer.remaining(); 71 | let mut reader = buffer.reader(); 72 | let length = match self.state { 73 | State::Waiting => { 74 | // 1 byte for the number tag, 4 bytes for the message length 75 | if start_remaining < 5 { 76 | return 
Err(protosocket::DeserializeError::IncompleteBuffer { 77 | next_message_size: 5, 78 | }); 79 | } 80 | let length: u32 = match rmp::decode::read_u32(&mut reader) { 81 | Ok(length) => length, 82 | Err(e) => { 83 | log::error!("decode length error: {e:?}"); 84 | return Err(protosocket::DeserializeError::InvalidBuffer); 85 | } 86 | }; 87 | self.state = State::ReadingLength(length); 88 | length 89 | } 90 | State::ReadingLength(length) => { 91 | let _ = reader.read(&mut [0; 5]).expect("skip parsing"); 92 | length 93 | } 94 | }; 95 | if start_remaining < (length + 5) as usize { 96 | return Err(protosocket::DeserializeError::IncompleteBuffer { 97 | next_message_size: (length + 5) as usize, 98 | }); 99 | } 100 | self.state = State::Waiting; 101 | 102 | rmp_serde::decode::from_read(&mut reader) 103 | .map_err(|e| { 104 | log::error!("decode error length {length}: {e:?}"); 105 | protosocket::DeserializeError::InvalidBuffer 106 | }) 107 | .map(|message| { 108 | let buffer = reader.into_inner(); 109 | let length = start_remaining - buffer.remaining(); 110 | log::debug!("decoded {length}: {message:?}"); 111 | (length, message) 112 | }) 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /protosocket-prost/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "protosocket-prost" 3 | description = "Message-oriented nonblocking tcp stream - protocol buffers bindings" 4 | version.workspace = true 5 | edition.workspace = true 6 | license.workspace = true 7 | authors.workspace = true 8 | readme.workspace = true 9 | repository.workspace = true 10 | keywords.workspace = true 11 | categories.workspace = true 12 | 13 | [dependencies] 14 | protosocket = { workspace = true } 15 | 16 | bytes = { workspace = true } 17 | log = { workspace = true } 18 | prost = { workspace = true } 19 | thiserror = { workspace = true } 20 | tokio = { workspace = true, features = ["rt"] } 21 | 
-------------------------------------------------------------------------------- /protosocket-prost/src/error.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | /// Result type for protosocket-prost. 4 | pub type Result = std::result::Result; 5 | 6 | /// Error type for protosocket-prost. 7 | #[derive(Clone, Debug, thiserror::Error)] 8 | pub enum Error { 9 | #[error("IO failure: {0}")] 10 | IoFailure(#[from] Arc), 11 | #[error("Bad address: {0}")] 12 | AddressError(#[from] core::net::AddrParseError), 13 | #[error("Requested resource was unable to respond: ({0})")] 14 | Dead(&'static str), 15 | } 16 | -------------------------------------------------------------------------------- /protosocket-prost/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Conveniences for using protocol buffers via `prost` with `protosocket`. 2 | //! 3 | //! See the example-proto directory for a complete example of how to use this crate. 4 | 5 | mod error; 6 | mod prost_client_registry; 7 | mod prost_serializer; 8 | mod prost_socket; 9 | 10 | pub use error::{Error, Result}; 11 | pub use prost_client_registry::ClientRegistry; 12 | pub use prost_serializer::ProstSerializer; 13 | pub use prost_socket::ProstClientConnectionBindings; 14 | pub use prost_socket::ProstServerConnectionBindings; 15 | -------------------------------------------------------------------------------- /protosocket-prost/src/prost_client_registry.rs: -------------------------------------------------------------------------------- 1 | use protosocket::{Connection, MessageReactor}; 2 | use tokio::{net::TcpStream, sync::mpsc}; 3 | 4 | use crate::{ProstClientConnectionBindings, ProstSerializer}; 5 | 6 | /// A factory for creating client connections to a `protosocket` server. 
7 | #[derive(Debug, Clone)] 8 | pub struct ClientRegistry { 9 | max_buffer_length: usize, 10 | max_queued_outbound_messages: usize, 11 | runtime: tokio::runtime::Handle, 12 | } 13 | 14 | impl ClientRegistry { 15 | /// Construct a new client registry. Connections will be spawned on the provided runtime. 16 | pub fn new(runtime: tokio::runtime::Handle) -> Self { 17 | log::trace!("new client registry"); 18 | Self { 19 | max_buffer_length: 4 * (2 << 20), 20 | max_queued_outbound_messages: 256, 21 | runtime, 22 | } 23 | } 24 | 25 | /// Sets the maximum read buffer length for connections created by this registry after 26 | /// the setting is applied. 27 | pub fn set_max_read_buffer_length(&mut self, max_buffer_length: usize) { 28 | self.max_buffer_length = max_buffer_length; 29 | } 30 | 31 | /// Sets the maximum queued outbound messages for connections created by this registry after 32 | /// the setting is applied. 33 | pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) { 34 | self.max_queued_outbound_messages = max_queued_outbound_messages; 35 | } 36 | 37 | /// Get a new connection to a `protosocket` server. 
38 | pub async fn register_client( 39 | &self, 40 | address: impl Into, 41 | message_reactor: Reactor, 42 | ) -> crate::Result> 43 | where 44 | Request: prost::Message + Default + Unpin + 'static, 45 | Response: prost::Message + Default + Unpin + 'static, 46 | Reactor: MessageReactor, 47 | { 48 | let address = address.into().parse()?; 49 | let stream = TcpStream::connect(address) 50 | .await 51 | .map_err(std::sync::Arc::new)?; 52 | stream.set_nodelay(true).map_err(std::sync::Arc::new)?; 53 | let (outbound, outbound_messages) = mpsc::channel(self.max_queued_outbound_messages); 54 | let connection = 55 | Connection::>::new( 56 | stream, 57 | address, 58 | ProstSerializer::default(), 59 | ProstSerializer::default(), 60 | self.max_buffer_length, 61 | self.max_queued_outbound_messages, 62 | outbound_messages, 63 | message_reactor, 64 | ); 65 | self.runtime.spawn(connection); 66 | Ok(outbound) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /protosocket-prost/src/prost_serializer.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | 3 | use protosocket::{DeserializeError, Deserializer, Serializer}; 4 | 5 | /// A stateless implementation of protosocket's `Serializer` and `Deserializer` 6 | /// traits using `prost` for encoding and decoding protocol buffers messages. 
7 | #[derive(Default, Debug)] 8 | pub struct ProstSerializer { 9 | pub(crate) _phantom: PhantomData<(Deserialized, Serialized)>, 10 | } 11 | 12 | impl Serializer for ProstSerializer 13 | where 14 | Deserialized: prost::Message + Default + Unpin, 15 | Serialized: prost::Message + Unpin, 16 | { 17 | type Message = Serialized; 18 | 19 | fn encode(&mut self, message: Self::Message, buffer: &mut Vec) { 20 | match message.encode_length_delimited(buffer) { 21 | Ok(_) => { 22 | log::debug!("encoded {message:?}"); 23 | } 24 | Err(e) => { 25 | log::error!("encoding error: {e:?}"); 26 | } 27 | } 28 | } 29 | } 30 | impl Deserializer for ProstSerializer 31 | where 32 | Deserialized: prost::Message + Default + Unpin, 33 | Serialized: prost::Message + Unpin, 34 | { 35 | type Message = Deserialized; 36 | 37 | fn decode( 38 | &mut self, 39 | mut buffer: impl bytes::Buf, 40 | ) -> std::result::Result<(usize, Self::Message), DeserializeError> { 41 | match prost::decode_length_delimiter(buffer.chunk()) { 42 | Ok(message_length) => { 43 | if buffer.remaining() < message_length + prost::length_delimiter_len(message_length) 44 | { 45 | return Err(DeserializeError::IncompleteBuffer { 46 | next_message_size: message_length, 47 | }); 48 | } 49 | } 50 | Err(e) => { 51 | log::trace!("can't read a length delimiter {e:?}"); 52 | return Err(DeserializeError::IncompleteBuffer { 53 | next_message_size: 10, 54 | }); 55 | } 56 | }; 57 | 58 | let start = buffer.remaining(); 59 | match ::decode_length_delimited(&mut buffer) { 60 | Ok(message) => { 61 | let length = start - buffer.remaining(); 62 | log::debug!("decoded {length}: {message:?}"); 63 | Ok((length, message)) 64 | } 65 | Err(e) => { 66 | log::warn!("could not decode message: {e:?}"); 67 | Err(DeserializeError::InvalidBuffer) 68 | } 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /protosocket-prost/src/prost_socket.rs: 
-------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | 3 | use protosocket::{ConnectionBindings, MessageReactor}; 4 | 5 | use crate::prost_serializer::ProstSerializer; 6 | 7 | /// A convenience type for binding a `ProstSerializer` to a server-side 8 | /// `protosocket::Connection`. 9 | pub struct ProstServerConnectionBindings { 10 | _phantom: PhantomData<(Request, Response, Reactor)>, 11 | } 12 | 13 | impl ConnectionBindings 14 | for ProstServerConnectionBindings 15 | where 16 | Request: prost::Message + Default + Unpin + 'static, 17 | Response: prost::Message + Unpin + 'static, 18 | Reactor: MessageReactor, 19 | { 20 | type Deserializer = ProstSerializer; 21 | type Serializer = ProstSerializer; 22 | type Reactor = Reactor; 23 | } 24 | 25 | /// A convenience type for binding a `ProstSerializer` to a client-side 26 | /// `protosocket::Connection`. 27 | pub struct ProstClientConnectionBindings { 28 | _phantom: PhantomData<(Request, Response, Reactor)>, 29 | } 30 | 31 | impl ConnectionBindings 32 | for ProstClientConnectionBindings 33 | where 34 | Request: prost::Message + Default + Unpin + 'static, 35 | Response: prost::Message + Default + Unpin + 'static, 36 | Reactor: MessageReactor, 37 | { 38 | type Deserializer = ProstSerializer; 39 | type Serializer = ProstSerializer; 40 | type Reactor = Reactor; 41 | } 42 | -------------------------------------------------------------------------------- /protosocket-rpc/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "protosocket-rpc" 3 | description = "RPC using protosockets" 4 | version.workspace = true 5 | authors.workspace = true 6 | repository.workspace = true 7 | edition.workspace = true 8 | license.workspace = true 9 | readme.workspace = true 10 | keywords.workspace = true 11 | categories.workspace = true 12 | 13 | [dependencies] 14 | protosocket = { workspace = true } 15 | 16 | futures = 
{ workspace = true } 17 | k-lock = { workspace = true } 18 | log = { workspace = true } 19 | tokio = { workspace = true } 20 | tokio-util = { workspace = true } 21 | thiserror = { workspace = true } 22 | 23 | [dev-dependencies] 24 | prost = { workspace = true } 25 | -------------------------------------------------------------------------------- /protosocket-rpc/README.md: -------------------------------------------------------------------------------- 1 | # protosocket-rpc 2 | For making RPC servers and clients. 3 | 4 | A protosocket rpc server consists of a couple key traits: 5 | 6 | * `SocketService`: Your service that takes new connections and produces `ConnectionService`s. 7 | * `ConnectionService`: Your service that manages a connection, creating new RPC futures and doing bookkeeping. 8 | * `Message`: Your way to get protosocket metadata out of your encoded messages. 9 | 10 | A protosocket rpc client is a little more basic, just relying on common protosocket traits and `Message`. 11 | 12 | Protosocket rpc lets you choose any encoding and does not wrap your messages at all. The bytes you 13 | send are the bytes which are sent. This means you need to provide a way to communicate the basic protosocket 14 | metadata on each message: A message_id u64 and a control code u8. The `Message` trait helps to ensure you get 15 | the needful functions wired through. 16 | -------------------------------------------------------------------------------- /protosocket-rpc/src/client/configuration.rs: -------------------------------------------------------------------------------- 1 | use std::net::SocketAddr; 2 | 3 | use protosocket::Connection; 4 | use tokio::sync::mpsc; 5 | 6 | use crate::{ 7 | client::reactor::completion_reactor::{DoNothingMessageHandler, RpcCompletionReactor}, 8 | Message, 9 | }; 10 | 11 | use super::{reactor::completion_reactor::RpcCompletionConnectionBindings, RpcClient}; 12 | 13 | /// Configuration for a `protosocket` rpc client. 
14 | #[derive(Debug, Clone)] 15 | pub struct Configuration { 16 | max_buffer_length: usize, 17 | max_queued_outbound_messages: usize, 18 | } 19 | 20 | impl Default for Configuration { 21 | fn default() -> Self { 22 | Self { 23 | max_buffer_length: 4 * (2 << 20), 24 | max_queued_outbound_messages: 256, 25 | } 26 | } 27 | } 28 | 29 | impl Configuration { 30 | /// Max buffer length limits the max message size. Try to use a buffer length that is at least 4 times the largest message you want to support. 31 | /// 32 | /// Default: 4MiB 33 | pub fn max_buffer_length(&mut self, max_buffer_length: usize) { 34 | self.max_buffer_length = max_buffer_length; 35 | } 36 | 37 | /// Max messages that will be queued up waiting for send on the client channel. 38 | /// 39 | /// Default: 256 40 | pub fn max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) { 41 | self.max_queued_outbound_messages = max_queued_outbound_messages; 42 | } 43 | } 44 | 45 | /// Connect a new protosocket rpc client to a server 46 | pub async fn connect( 47 | address: SocketAddr, 48 | configuration: &Configuration, 49 | ) -> Result< 50 | ( 51 | RpcClient, 52 | protosocket::Connection>, 53 | ), 54 | crate::Error, 55 | > 56 | where 57 | Deserializer: protosocket::Deserializer + Default + 'static, 58 | Serializer: protosocket::Serializer + Default + 'static, 59 | Deserializer::Message: Message, 60 | Serializer::Message: Message, 61 | { 62 | log::trace!("new client {address}, {configuration:?}"); 63 | 64 | let stream = tokio::net::TcpStream::connect(address).await?; 65 | stream.set_nodelay(true)?; 66 | 67 | let message_reactor: RpcCompletionReactor< 68 | Deserializer::Message, 69 | DoNothingMessageHandler, 70 | > = RpcCompletionReactor::new(Default::default()); 71 | let (outbound, outbound_messages) = mpsc::channel(configuration.max_queued_outbound_messages); 72 | let rpc_client = RpcClient::new(outbound, &message_reactor); 73 | 74 | // Tie outbound_messages to message_reactor via a 
protosocket::Connection 75 | let connection = Connection::>::new( 76 | stream, 77 | address, 78 | Deserializer::default(), 79 | Serializer::default(), 80 | configuration.max_buffer_length, 81 | configuration.max_queued_outbound_messages, 82 | outbound_messages, 83 | message_reactor, 84 | ); 85 | 86 | Ok((rpc_client, connection)) 87 | } 88 | -------------------------------------------------------------------------------- /protosocket-rpc/src/client/mod.rs: -------------------------------------------------------------------------------- 1 | mod configuration; 2 | mod reactor; 3 | mod rpc_client; 4 | 5 | pub use configuration::{connect, Configuration}; 6 | pub use rpc_client::RpcClient; 7 | 8 | pub use reactor::completion_streaming::StreamingCompletion; 9 | pub use reactor::completion_unary::UnaryCompletion; 10 | -------------------------------------------------------------------------------- /protosocket-rpc/src/client/reactor/completion_reactor.rs: -------------------------------------------------------------------------------- 1 | use core::panic; 2 | use std::{ 3 | collections::hash_map::Entry, 4 | marker::PhantomData, 5 | sync::{atomic::AtomicBool, Arc}, 6 | }; 7 | 8 | use protosocket::{ConnectionBindings, MessageReactor, ReactorStatus}; 9 | 10 | use crate::{message::ProtosocketControlCode, Message}; 11 | 12 | use super::completion_registry::{Completion, CompletionRegistry, RpcRegistrar}; 13 | 14 | pub struct RpcCompletionConnectionBindings( 15 | PhantomData<(Serializer, Deserializer)>, 16 | ); 17 | impl ConnectionBindings 18 | for RpcCompletionConnectionBindings 19 | where 20 | Serializer: protosocket::Serializer + 'static, 21 | Serializer::Message: Message, 22 | Deserializer: protosocket::Deserializer + 'static, 23 | Deserializer::Message: Message, 24 | { 25 | type Deserializer = Deserializer; 26 | type Serializer = Serializer; 27 | type Reactor = 28 | RpcCompletionReactor>; 29 | } 30 | 31 | #[derive(Debug)] 32 | pub struct RpcCompletionReactor 33 | where 34 | 
Inbound: Message, 35 | TUnregisteredMessageHandler: UnregisteredMessageHandler, 36 | { 37 | rpc_registry: CompletionRegistry, 38 | is_alive: Arc, 39 | unregistered_message_handler: TUnregisteredMessageHandler, 40 | } 41 | impl 42 | RpcCompletionReactor 43 | where 44 | Inbound: Message, 45 | TUnregisteredMessageHandler: UnregisteredMessageHandler, 46 | { 47 | #[allow(clippy::new_without_default)] 48 | pub fn new(unregistered_message_handler: TUnregisteredMessageHandler) -> Self { 49 | Self { 50 | rpc_registry: CompletionRegistry::new(), 51 | is_alive: Arc::new(AtomicBool::new(true)), 52 | unregistered_message_handler, 53 | } 54 | } 55 | 56 | pub fn alive_handle(&self) -> Arc { 57 | self.is_alive.clone() 58 | } 59 | 60 | pub fn in_flight_submission_handle(&self) -> RpcRegistrar { 61 | self.rpc_registry.in_flight_submission_handle() 62 | } 63 | } 64 | 65 | impl Drop 66 | for RpcCompletionReactor 67 | where 68 | Inbound: Message, 69 | TUnregisteredMessageHandler: UnregisteredMessageHandler, 70 | { 71 | fn drop(&mut self) { 72 | self.is_alive 73 | .store(false, std::sync::atomic::Ordering::Release); 74 | } 75 | } 76 | 77 | impl MessageReactor 78 | for RpcCompletionReactor 79 | where 80 | Inbound: Message, 81 | TUnregisteredMessageHandler: UnregisteredMessageHandler, 82 | { 83 | type Inbound = Inbound; 84 | 85 | fn on_inbound_messages( 86 | &mut self, 87 | messages: impl IntoIterator, 88 | ) -> ReactorStatus { 89 | self.rpc_registry.take_new_rpc_lifecycle_actions(); 90 | 91 | for message_from_the_network in messages.into_iter() { 92 | let message_id_from_the_network = message_from_the_network.message_id(); 93 | match message_from_the_network.control_code() { 94 | ProtosocketControlCode::Normal => (), 95 | ProtosocketControlCode::Cancel => { 96 | log::debug!("{message_id_from_the_network} cancelling command"); 97 | self.rpc_registry.deregister(message_id_from_the_network); 98 | continue; 99 | } 100 | ProtosocketControlCode::End => { 101 | 
log::debug!("{message_id_from_the_network} command end of stream"); 102 | self.rpc_registry.deregister(message_id_from_the_network); 103 | continue; 104 | } 105 | } 106 | match self.rpc_registry.entry(message_id_from_the_network) { 107 | Entry::Occupied(mut registered_rpc) => { 108 | if let Completion::RemoteStreaming(stream) = registered_rpc.get_mut() { 109 | if let Err(e) = stream.send(message_from_the_network) { 110 | log::debug!("{message_id_from_the_network} completion channel closed - did the client lose interest in this request? {e:?}"); 111 | registered_rpc.remove(); 112 | } 113 | } else if let Completion::Unary(completion) = registered_rpc.remove() { 114 | if let Err(e) = completion.send(Ok(message_from_the_network)) { 115 | log::debug!("{message_id_from_the_network} completion channel closed - did the client lose interest in this request? {e:?}"); 116 | } 117 | } else { 118 | panic!("{message_id_from_the_network} unexpected command response type. Sorry, I wanted to borrow for streaming and remove by value for unary without doing 2 map lookups, so I couldn't match"); 119 | } 120 | } 121 | Entry::Vacant(_vacant_entry) => { 122 | // Possibly a cancelled response if this is a client, and probably a new rpc if it's a server 123 | log::debug!( 124 | "{message_id_from_the_network} command response for command that was not in flight" 125 | ); 126 | self.unregistered_message_handler 127 | .on_message(message_from_the_network, &mut self.rpc_registry); 128 | } 129 | } 130 | } 131 | ReactorStatus::Continue 132 | } 133 | } 134 | 135 | pub trait UnregisteredMessageHandler: Send + Unpin + 'static { 136 | type Inbound: Message; 137 | 138 | fn on_message( 139 | &mut self, 140 | message: Self::Inbound, 141 | rpc_registry: &mut CompletionRegistry, 142 | ); 143 | } 144 | 145 | #[derive(Debug)] 146 | pub struct DoNothingMessageHandler { 147 | _phantom: PhantomData, 148 | } 149 | impl UnregisteredMessageHandler for DoNothingMessageHandler { 150 | type Inbound = T; 151 | 152 | 
fn on_message( 153 | &mut self, 154 | _message: Self::Inbound, 155 | _rpc_registry: &mut CompletionRegistry, 156 | ) { 157 | } 158 | } 159 | 160 | impl Default for DoNothingMessageHandler { 161 | fn default() -> Self { 162 | Self { 163 | _phantom: PhantomData, 164 | } 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /protosocket-rpc/src/client/reactor/completion_registry.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::{hash_map::Entry, HashMap}, 3 | sync::Arc, 4 | }; 5 | 6 | use k_lock::Mutex; 7 | use tokio::sync::{mpsc, oneshot}; 8 | 9 | use crate::Message; 10 | 11 | #[derive(Debug, Default)] 12 | pub struct CompletionRegistry 13 | where 14 | Inbound: Message, 15 | { 16 | #[allow(clippy::type_complexity)] 17 | in_flight_submission: Arc>>>, 18 | in_flight_buffer: HashMap>, 19 | in_flight: HashMap>, 20 | } 21 | 22 | impl CompletionRegistry 23 | where 24 | Inbound: Message, 25 | { 26 | pub fn new() -> Self { 27 | Self { 28 | in_flight_submission: Default::default(), 29 | in_flight_buffer: Default::default(), 30 | in_flight: Default::default(), 31 | } 32 | } 33 | 34 | pub fn in_flight_submission_handle(&self) -> RpcRegistrar { 35 | RpcRegistrar { 36 | in_flight_submission: self.in_flight_submission.clone(), 37 | } 38 | } 39 | 40 | /// Returns a list of cancelled commands 41 | pub fn take_new_rpc_lifecycle_actions(&mut self) { 42 | { 43 | let mut in_flight_submission = self 44 | .in_flight_submission 45 | .lock() 46 | .expect("brief internal mutex must work"); 47 | if in_flight_submission.is_empty() { 48 | return; 49 | } 50 | // only lock for the swap - this makes sure every time the in_flight_buffer is 51 | // used, it's for O(1) time. 
52 | std::mem::swap(&mut self.in_flight_buffer, &mut *in_flight_submission); 53 | } 54 | for (command_id, completion_state) in self.in_flight_buffer.drain() { 55 | match completion_state { 56 | CompletionState::InProgress(completion) => { 57 | self.in_flight.insert(command_id, completion); 58 | } 59 | CompletionState::Done => { 60 | log::debug!("{command_id} command done"); 61 | self.in_flight.remove(&command_id); 62 | } 63 | } 64 | } 65 | } 66 | 67 | pub fn deregister(&mut self, message_id: u64) { 68 | self.in_flight.remove(&message_id); 69 | } 70 | 71 | pub fn entry(&mut self, message_id: u64) -> Entry<'_, u64, Completion> { 72 | self.in_flight.entry(message_id) 73 | } 74 | } 75 | 76 | /// For removing a tracked rpc from the in-flight map when it is no longer needed 77 | #[derive(Debug)] 78 | pub struct CompletionGuard 79 | where 80 | Inbound: Message, 81 | Outbound: Message, 82 | { 83 | closed: bool, 84 | in_flight_submission: Arc>>>, 85 | message_id: u64, 86 | raw_submission_queue: tokio::sync::mpsc::Sender, 87 | } 88 | 89 | impl CompletionGuard 90 | where 91 | Inbound: Message, 92 | Outbound: Message, 93 | { 94 | pub fn set_closed(&mut self) { 95 | self.closed = true; 96 | } 97 | } 98 | 99 | impl Drop for CompletionGuard 100 | where 101 | Inbound: Message, 102 | Outbound: Message, 103 | { 104 | fn drop(&mut self) { 105 | self.in_flight_submission 106 | .lock() 107 | .expect("brief internal mutex must work") 108 | // This doesn't result in a prompt wake of the reactor 109 | .insert(self.message_id, CompletionState::Done); 110 | if !self.closed { 111 | if let Err(e) = self 112 | .raw_submission_queue 113 | .try_send(Outbound::cancelled(self.message_id)) 114 | { 115 | log::error!( 116 | "unable to send cancellation for message - this will abandon server rpcs {e:?}" 117 | ); 118 | } 119 | } 120 | } 121 | } 122 | 123 | #[derive(Debug, Clone)] 124 | pub struct RpcRegistrar 125 | where 126 | Inbound: Message, 127 | { 128 | in_flight_submission: Arc>>>, 129 | } 130 | 
131 | impl RpcRegistrar 132 | where 133 | Inbound: Message, 134 | { 135 | // The triple-buffered message queue mutex is carefully controlled - it can't panic unless the memory allocator panics. 136 | // Probably the server should crash if that happens. 137 | // Note that this is just for tracking - you have to register the completion before sending the message, or else you might 138 | // miss the completion. 139 | #[allow(clippy::expect_used)] 140 | #[must_use] 141 | pub fn register_completion( 142 | &self, 143 | message_id: u64, 144 | completion: Completion, 145 | raw_submission_queue: tokio::sync::mpsc::Sender, 146 | ) -> CompletionGuard { 147 | self.in_flight_submission 148 | .lock() 149 | .expect("brief internal mutex must work") 150 | .insert(message_id, CompletionState::InProgress(completion)); 151 | CompletionGuard { 152 | in_flight_submission: self.in_flight_submission.clone(), 153 | message_id, 154 | closed: false, 155 | raw_submission_queue, 156 | } 157 | } 158 | } 159 | 160 | #[derive(Debug)] 161 | pub enum Completion 162 | where 163 | Inbound: Message, 164 | { 165 | Unary(oneshot::Sender>), 166 | RemoteStreaming(mpsc::UnboundedSender), 167 | } 168 | 169 | #[derive(Debug)] 170 | pub enum CompletionState 171 | where 172 | Inbound: Message, 173 | { 174 | InProgress(Completion), 175 | Done, 176 | } 177 | -------------------------------------------------------------------------------- /protosocket-rpc/src/client/reactor/completion_streaming.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | pin::{pin, Pin}, 3 | task::{Context, Poll}, 4 | }; 5 | 6 | use tokio::sync::mpsc; 7 | 8 | use super::completion_registry::CompletionGuard; 9 | use crate::Message; 10 | 11 | /// A completion for a streaming RPC. 12 | /// 13 | /// Make sure you process this stream quickly, and drop data yourself if you have to. The 14 | /// server will send data as quickly as it can. 
15 | #[derive(Debug)] 16 | pub struct StreamingCompletion 17 | where 18 | Response: Message, 19 | Request: Message, 20 | { 21 | completion: mpsc::UnboundedReceiver, 22 | completion_guard: CompletionGuard, 23 | closed: bool, 24 | nexts: Vec, 25 | } 26 | 27 | /// SAFETY: There is no unsafe code in this implementation 28 | impl Unpin for StreamingCompletion 29 | where 30 | Response: Message, 31 | Request: Message, 32 | { 33 | } 34 | 35 | const LIMIT: usize = 16; 36 | 37 | impl StreamingCompletion 38 | where 39 | Response: Message, 40 | Request: Message, 41 | { 42 | pub(crate) fn new( 43 | completion: mpsc::UnboundedReceiver, 44 | completion_guard: CompletionGuard, 45 | ) -> Self { 46 | Self { 47 | completion, 48 | completion_guard, 49 | closed: false, 50 | nexts: Vec::with_capacity(LIMIT), 51 | } 52 | } 53 | } 54 | 55 | impl futures::Stream for StreamingCompletion 56 | where 57 | Response: Message, 58 | Request: Message, 59 | { 60 | type Item = crate::Result; 61 | 62 | fn poll_next(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll> { 63 | if self.closed { 64 | return Poll::Ready(None); 65 | } 66 | if self.nexts.is_empty() { 67 | let Self { 68 | completion, nexts, .. 69 | } = &mut *self; 70 | let received = pin!(completion).poll_recv_many(context, nexts, LIMIT); 71 | match received { 72 | Poll::Ready(count) => { 73 | if count == 0 { 74 | self.closed = true; 75 | self.completion_guard.set_closed(); 76 | return Poll::Ready(Some(Err(crate::Error::Finished))); 77 | } 78 | // because it is a vector, we have to consume in reverse order. This is because 79 | // of the poll_recv_many argument type. 
80 | nexts.reverse(); 81 | } 82 | Poll::Pending => return Poll::Pending, 83 | } 84 | } 85 | match self.nexts.pop() { 86 | Some(next) => Poll::Ready(Some(Ok(next))), 87 | None => { 88 | log::error!("unexpected empty nexts"); 89 | Poll::Ready(None) 90 | } 91 | } 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /protosocket-rpc/src/client/reactor/completion_unary.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | future::Future, 3 | pin::{pin, Pin}, 4 | task::{Context, Poll}, 5 | }; 6 | 7 | use tokio::sync::oneshot; 8 | 9 | use super::completion_registry::CompletionGuard; 10 | use crate::Message; 11 | 12 | /// A completion for a unary RPC. 13 | #[derive(Debug)] 14 | pub struct UnaryCompletion 15 | where 16 | Response: Message, 17 | Request: Message, 18 | { 19 | completion: oneshot::Receiver>, 20 | completion_guard: CompletionGuard, 21 | } 22 | 23 | impl UnaryCompletion 24 | where 25 | Response: Message, 26 | Request: Message, 27 | { 28 | pub(crate) fn new( 29 | completion: oneshot::Receiver>, 30 | completion_guard: CompletionGuard, 31 | ) -> Self { 32 | Self { 33 | completion, 34 | completion_guard, 35 | } 36 | } 37 | } 38 | 39 | impl Future for UnaryCompletion 40 | where 41 | Response: Message, 42 | Request: Message, 43 | { 44 | type Output = crate::Result; 45 | 46 | fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll { 47 | match pin!(&mut self.completion).poll(context) { 48 | Poll::Ready(result) => { 49 | self.completion_guard.set_closed(); 50 | match result { 51 | Ok(done) => Poll::Ready(done), 52 | Err(_cancelled) => Poll::Ready(Err(crate::Error::CancelledRemotely)), 53 | } 54 | } 55 | Poll::Pending => Poll::Pending, 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /protosocket-rpc/src/client/reactor/mod.rs: -------------------------------------------------------------------------------- 
1 | pub mod completion_reactor; 2 | pub mod completion_registry; 3 | pub mod completion_streaming; 4 | pub mod completion_unary; 5 | -------------------------------------------------------------------------------- /protosocket-rpc/src/client/reactor/rpc_drop_guard.rs: -------------------------------------------------------------------------------- 1 | use crate::Message; 2 | 3 | #[derive(Debug)] 4 | pub struct RpcDropGuard 5 | where 6 | Request: Message, 7 | { 8 | cancellation_submission_queue: tokio::sync::mpsc::Sender, 9 | message_id: u64, 10 | completed: bool, 11 | } 12 | 13 | impl RpcDropGuard 14 | where 15 | Request: Message, 16 | { 17 | pub fn new( 18 | cancellation_submission_queue: tokio::sync::mpsc::Sender, 19 | message_id: u64, 20 | ) -> Self { 21 | Self { 22 | cancellation_submission_queue, 23 | message_id, 24 | completed: false, 25 | } 26 | } 27 | 28 | /// Set this to avoid sending a cancellation message when the guard is dropped. 29 | pub fn set_complete(&mut self) { 30 | self.completed = true; 31 | } 32 | 33 | pub fn is_complete(&self) -> bool { 34 | self.completed 35 | } 36 | } 37 | 38 | impl Drop for RpcDropGuard 39 | where 40 | Request: Message, 41 | { 42 | fn drop(&mut self) { 43 | if !self.completed { 44 | let mut message = Request::cancelled(); 45 | message.set_message_id(self.message_id); 46 | if let Err(e) = self.cancellation_submission_queue.try_send(message) { 47 | log::warn!("failed to send cancellation message: {:?}", e); 48 | } 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /protosocket-rpc/src/client/rpc_client.rs: -------------------------------------------------------------------------------- 1 | use std::sync::{atomic::AtomicBool, Arc}; 2 | 3 | use tokio::sync::{mpsc, oneshot}; 4 | 5 | use super::reactor::completion_reactor::{DoNothingMessageHandler, RpcCompletionReactor}; 6 | use super::reactor::completion_registry::{Completion, CompletionGuard, RpcRegistrar}; 7 | use 
super::reactor::{ 8 | completion_streaming::StreamingCompletion, completion_unary::UnaryCompletion, 9 | }; 10 | use crate::Message; 11 | 12 | /// A client for sending RPCs to a protosockets rpc server. 13 | /// 14 | /// It handles sending messages to the server and associating the responses. 15 | /// Messages are sent and received in any order, asynchronously, and support cancellation. 16 | /// To cancel an RPC, drop the response future. 17 | #[derive(Debug, Clone)] 18 | pub struct RpcClient 19 | where 20 | Request: Message, 21 | Response: Message, 22 | { 23 | #[allow(clippy::type_complexity)] 24 | in_flight_submission: RpcRegistrar, 25 | submission_queue: tokio::sync::mpsc::Sender, 26 | is_alive: Arc, 27 | } 28 | 29 | impl RpcClient 30 | where 31 | Request: Message, 32 | Response: Message, 33 | { 34 | pub(crate) fn new( 35 | submission_queue: mpsc::Sender, 36 | message_reactor: &RpcCompletionReactor>, 37 | ) -> Self { 38 | Self { 39 | submission_queue, 40 | in_flight_submission: message_reactor.in_flight_submission_handle(), 41 | is_alive: message_reactor.alive_handle(), 42 | } 43 | } 44 | 45 | /// Checking this before using the client does not guarantee that the client is still alive when you send 46 | /// your message. It may be useful for connection pool implementations - for example, [bb8::ManageConnection](https://github.com/djc/bb8/blob/09a043c001b3c15514d9f03991cfc87f7118a000/bb8/src/api.rs#L383-L384)'s 47 | /// is_valid and has_broken could be bound to this function to help the pool cycle out broken connections. 48 | pub fn is_alive(&self) -> bool { 49 | self.is_alive.load(std::sync::atomic::Ordering::Relaxed) 50 | } 51 | 52 | /// Send a server-streaming rpc to the server. 53 | /// 54 | /// This function only sends the request. You must consume the completion stream to get the response. 55 | #[must_use = "You must await the completion to get the response. 
If you drop the completion, the request will be cancelled."] 56 | pub async fn send_streaming( 57 | &self, 58 | request: Request, 59 | ) -> crate::Result> { 60 | let (sender, completion) = mpsc::unbounded_channel(); 61 | let completion_guard = self 62 | .send_message(Completion::RemoteStreaming(sender), request) 63 | .await?; 64 | 65 | let completion = StreamingCompletion::new(completion, completion_guard); 66 | 67 | Ok(completion) 68 | } 69 | 70 | /// Send a unary rpc to the server. 71 | /// 72 | /// This function only sends the request. You must await the completion to get the response. 73 | #[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."] 74 | pub async fn send_unary( 75 | &self, 76 | request: Request, 77 | ) -> crate::Result> { 78 | let (completor, completion) = oneshot::channel(); 79 | let completion_guard = self 80 | .send_message(Completion::Unary(completor), request) 81 | .await?; 82 | 83 | let completion = UnaryCompletion::new(completion, completion_guard); 84 | 85 | Ok(completion) 86 | } 87 | 88 | async fn send_message( 89 | &self, 90 | completion: Completion, 91 | request: Request, 92 | ) -> crate::Result> { 93 | if !self.is_alive.load(std::sync::atomic::Ordering::Relaxed) { 94 | // early-out if the connection is closed 95 | return Err(crate::Error::ConnectionIsClosed); 96 | } 97 | let completion_guard = self.in_flight_submission.register_completion( 98 | request.message_id(), 99 | completion, 100 | self.submission_queue.clone(), 101 | ); 102 | self.submission_queue 103 | .send(request) 104 | .await 105 | .map_err(|_e| crate::Error::ConnectionIsClosed) 106 | .map(|_| completion_guard) 107 | } 108 | } 109 | 110 | #[cfg(test)] 111 | mod test { 112 | use std::future::Future; 113 | use std::pin::pin; 114 | use std::task::Context; 115 | use std::task::Poll; 116 | 117 | use futures::task::noop_waker_ref; 118 | 119 | use 
crate::client::reactor::completion_reactor::DoNothingMessageHandler;
    use crate::client::reactor::completion_reactor::RpcCompletionReactor;
    use crate::Message;

    use super::RpcClient;

    // Test stand-in message: the low 32 bits hold the message id, the bits
    // above 32 hold the control code.
    impl Message for u64 {
        fn message_id(&self) -> u64 {
            *self & 0xffffffff
        }

        fn control_code(&self) -> crate::ProtosocketControlCode {
            match *self >> 32 {
                0 => crate::ProtosocketControlCode::Normal,
                1 => crate::ProtosocketControlCode::Cancel,
                2 => crate::ProtosocketControlCode::End,
                _ => unreachable!("invalid control code"),
            }
        }

        fn set_message_id(&mut self, message_id: u64) {
            *self = (*self & 0xf00000000) | message_id;
        }

        fn cancelled(message_id: u64) -> Self {
            (1_u64 << 32) | message_id
        }

        fn ended(message_id: u64) -> Self {
            (2 << 32) | message_id
        }
    }

    // Busy-wait driver: polls `f` with a noop waker until it resolves.
    fn drive_future<F: Future>(f: F) -> F::Output {
        let mut f = pin!(f);
        let mut context = Context::from_waker(noop_waker_ref());
        loop {
            if let Poll::Ready(result) = f.as_mut().poll(&mut context) {
                break result;
            }
        }
    }

    // A client wired to an in-memory channel standing in for the connection.
    #[allow(clippy::type_complexity)]
    fn get_client() -> (
        tokio::sync::mpsc::Receiver<u64>,
        RpcClient<u64, u64>,
        RpcCompletionReactor<u64, DoNothingMessageHandler<u64>>,
    ) {
        let (sender, remote_end) = tokio::sync::mpsc::channel::<u64>(10);
        let rpc_reactor =
            RpcCompletionReactor::<u64, _>::new(DoNothingMessageHandler::default());
        let client = RpcClient::new(sender, &rpc_reactor);
        (remote_end, client, rpc_reactor)
    }

    #[test]
    fn unary_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();

        let response = drive_future(client.send_unary(4)).expect("can send");
        assert_eq!(4, remote_end.blocking_recv().expect("a request is sent"));
        assert!(remote_end.is_empty(), "no more messages yet");

        drop(response);

        assert_eq!(
            (1 << 32) + 4,
            remote_end.blocking_recv().expect("a cancel is sent")
        );
    }

    #[test]
    fn streaming_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();

        let response = drive_future(client.send_streaming(4)).expect("can send");
        assert_eq!(4, remote_end.blocking_recv().expect("a request is sent"));
        assert!(remote_end.is_empty(), "no more messages yet");

        drop(response);

        assert_eq!(
            (1 << 32) + 4,
            remote_end.blocking_recv().expect("a cancel is sent")
        );
    }
}

// ----- protosocket-rpc/src/error.rs -----

/// Result type for protosocket-rpc-client.
pub type Result<T> = std::result::Result<T, Error>;

/// Error type for protosocket-rpc-client.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Transport-level failure from the underlying socket.
    #[error("IO failure: {0}")]
    IoFailure(#[from] std::io::Error),
    /// The peer cancelled this rpc.
    #[error("Rpc was cancelled remotely")]
    CancelledRemotely,
    /// The connection backing this client or server has shut down.
    #[error("Connection is closed")]
    ConnectionIsClosed,
    /// The rpc (or its response stream) ran to completion.
    #[error("Rpc finished")]
    Finished,
}

// ----- protosocket-rpc/src/lib.rs -----

//! Protosocket RPC
//!
//! This crate provides an rpc-style client and server for the protosocket
//! protocol. You can use whatever encoding you want, but you must provide both a
//! `Serializer` and a `Deserializer` for your messages. If you use `prost`,
//! you can use the `protosocket-prost` crate to provide these implementations.
//! See example-proto for an example of how to use this crate with protocol buffers.
//!
//! Messages must provide a `Message` implementation, which includes a `message_id`
//!
and a `control_code`. The `message_id` is used to correlate requests and responses, 11 | //! while the `control_code` is used to provide special handling for messages. You can 12 | //! receive Cancel when an rpc is aborted, and End when a streaming rpc is complete. 13 | //! 14 | //! This RPC client is medium-low level wrapper around the low level protosocket crate, 15 | //! adding a layer of RPC semantics. You are expected to write a wrapper with the functions 16 | //! that make sense for your application, and use this client as the transport layer. 17 | //! 18 | //! Clients and servers handle RPC cancellation. 19 | //! 20 | //! Clients and servers need to agree about the request and response semantics. While it is 21 | //! supported to have dynamic streaming/unary response types, it is recommended to instead 22 | //! use separate rpc-initiating messages for streaming and unary rpcs. 23 | 24 | mod error; 25 | mod message; 26 | 27 | pub mod client; 28 | pub mod server; 29 | 30 | pub use error::{Error, Result}; 31 | pub use message::{Message, ProtosocketControlCode}; 32 | -------------------------------------------------------------------------------- /protosocket-rpc/src/message.rs: -------------------------------------------------------------------------------- 1 | /// A protosocket message. 2 | pub trait Message: std::fmt::Debug + Send + Unpin + 'static { 3 | /// This is used to relate requests to responses. An RPC response has the same id as the request that generated it. 4 | fn message_id(&self) -> u64; 5 | 6 | /// Set the protosocket behavior of this message. 7 | fn control_code(&self) -> ProtosocketControlCode; 8 | 9 | /// This is used to relate requests to responses. An RPC response has the same id as the request that generated it. 10 | /// When the message is sent, protosocket will set this value. 
11 | fn set_message_id(&mut self, message_id: u64); 12 | 13 | /// Create a message with a message with a cancel control code - used by the framework to handle cancellation. 14 | fn cancelled(message_id: u64) -> Self; 15 | 16 | /// Create a message with a message with an ended control code - used by the framework to handle streaming completion. 17 | fn ended(message_id: u64) -> Self; 18 | } 19 | 20 | #[derive(Debug, Clone, Copy)] 21 | #[repr(u8)] 22 | pub enum ProtosocketControlCode { 23 | /// No special behavior 24 | Normal = 0, 25 | /// Cancel processing the message with this message's id 26 | Cancel = 1, 27 | /// End processing the message with this message's id - for response streaming 28 | End = 2, 29 | } 30 | 31 | impl ProtosocketControlCode { 32 | pub fn from_u8(value: u8) -> Self { 33 | match value { 34 | 0 => Self::Normal, 35 | 1 => Self::Cancel, 36 | 2 => Self::End, 37 | _ => Self::Cancel, 38 | } 39 | } 40 | 41 | pub fn as_u8(&self) -> u8 { 42 | match self { 43 | Self::Normal => 0, 44 | Self::Cancel => 1, 45 | Self::End => 2, 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /protosocket-rpc/src/server/abortable.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | future::Future, 3 | pin::{pin, Pin}, 4 | sync::{ 5 | atomic::{AtomicUsize, Ordering}, 6 | Arc, 7 | }, 8 | task::{Context, Poll}, 9 | }; 10 | 11 | use futures::{task::AtomicWaker, Stream}; 12 | 13 | impl Unpin for IdentifiableAbortable {} 14 | 15 | #[derive(Debug)] 16 | pub struct IdentifiableAbortable { 17 | f: F, 18 | aborted: Arc, 19 | waker: Arc, 20 | id: u64, 21 | } 22 | 23 | impl IdentifiableAbortable { 24 | pub fn new(id: u64, f: F) -> (Self, IdentifiableAbortHandle) { 25 | let aborted = Arc::new(AtomicUsize::new(0)); 26 | let waker = Arc::new(AtomicWaker::new()); 27 | ( 28 | Self { 29 | f, 30 | aborted: aborted.clone(), 31 | waker: waker.clone(), 32 | id, 33 | }, 34 | 
IdentifiableAbortHandle { aborted, waker }, 35 | ) 36 | } 37 | } 38 | 39 | #[derive(Debug)] 40 | pub enum AbortableState { 41 | Abort, 42 | Aborted, 43 | Ready(T), 44 | } 45 | 46 | impl Future for IdentifiableAbortable 47 | where 48 | F: Future + Unpin, 49 | { 50 | type Output = (u64, AbortableState>); 51 | 52 | fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll { 53 | let state = self.aborted.load(Ordering::Relaxed); 54 | if 1 == state { 55 | self.aborted.store(2, Ordering::Relaxed); 56 | return Poll::Ready((self.id, AbortableState::Abort)); 57 | } 58 | if 2 == state { 59 | return Poll::Ready((self.id, AbortableState::Aborted)); 60 | } 61 | self.waker.register(context.waker()); 62 | pin!(&mut self.f) 63 | .poll(context) 64 | .map(|output| (self.id, AbortableState::Ready(Ok(output)))) 65 | } 66 | } 67 | 68 | impl Stream for IdentifiableAbortable 69 | where 70 | S: Stream + Unpin, 71 | { 72 | type Item = (u64, AbortableState>); 73 | 74 | fn poll_next(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll> { 75 | self.waker.register(context.waker()); 76 | match self.aborted.load(Ordering::Relaxed) { 77 | 0 => { 78 | match pin!(&mut self.f).poll_next(context) { 79 | Poll::Ready(next) => { 80 | match next { 81 | Some(next) => { 82 | Poll::Ready(Some((self.id, AbortableState::Ready(Ok(next))))) 83 | } 84 | None => { 85 | // stream is done 86 | self.aborted.store(3, Ordering::Relaxed); 87 | Poll::Ready(Some(( 88 | self.id, 89 | AbortableState::Ready(Err(crate::Error::Finished)), 90 | ))) 91 | } 92 | } 93 | } 94 | Poll::Pending => Poll::Pending, 95 | } 96 | } 97 | 1 => { 98 | self.aborted.store(2, Ordering::Relaxed); 99 | Poll::Ready(Some((self.id, AbortableState::Abort))) 100 | } 101 | 2 => { 102 | self.aborted.store(3, Ordering::Relaxed); 103 | Poll::Ready(Some((self.id, AbortableState::Aborted))) 104 | } 105 | _ => Poll::Ready(None), 106 | } 107 | } 108 | } 109 | 110 | #[derive(Debug)] 111 | pub struct IdentifiableAbortHandle { 112 | aborted: 
Arc, 113 | waker: Arc, 114 | } 115 | impl IdentifiableAbortHandle { 116 | /// Send an abort to the future or stream. 117 | pub fn abort(&self) { 118 | let _ = self 119 | .aborted 120 | .compare_exchange(0, 1, Ordering::Release, Ordering::Relaxed); 121 | self.waker.wake(); 122 | } 123 | 124 | /// Mark the future or stream as externally cancelled - don't send a cancellation 125 | pub fn mark_aborted(&self) { 126 | self.aborted.store(2, Ordering::Relaxed); 127 | self.waker.wake(); 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /protosocket-rpc/src/server/connection_server.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashMap, 3 | future::Future, 4 | pin::{pin, Pin}, 5 | task::{Context, Poll}, 6 | }; 7 | 8 | use futures::{ 9 | stream::{FuturesUnordered, SelectAll}, 10 | Stream, 11 | }; 12 | use tokio::sync::mpsc; 13 | use tokio_util::sync::PollSender; 14 | 15 | use crate::{server::RpcKind, Error, Message, ProtosocketControlCode}; 16 | 17 | use super::{ 18 | abortable::{AbortableState, IdentifiableAbortHandle, IdentifiableAbortable}, 19 | ConnectionService, 20 | }; 21 | 22 | #[derive(Debug)] 23 | pub struct RpcConnectionServer 24 | where 25 | TConnectionServer: ConnectionService, 26 | { 27 | connection_server: TConnectionServer, 28 | inbound: mpsc::UnboundedReceiver<::Request>, 29 | outbound: PollSender<::Response>, 30 | next_messages_buffer: Vec<::Request>, 31 | outstanding_unary_rpcs: 32 | FuturesUnordered>, 33 | outstanding_streaming_rpcs: SelectAll>, 34 | aborts: HashMap, 35 | } 36 | 37 | impl Future for RpcConnectionServer 38 | where 39 | TConnectionServer: ConnectionService, 40 | { 41 | type Output = Result<(), crate::Error>; 42 | 43 | fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll { 44 | // receive new messages 45 | if let Some(early_out) = self.as_mut().poll_receive_buffer(context) { 46 | return early_out; 47 | } 48 
| // either we're pending on inbound or we're awake 49 | self.as_mut().handle_message_buffer(); 50 | 51 | // retire and advance outstanding rpcs 52 | if let Some(early_out) = self.as_mut().poll_advance_unary_rpcs(context) { 53 | return early_out; 54 | } 55 | if let Some(early_out) = self.poll_advance_streaming_rpcs(context) { 56 | return early_out; 57 | } 58 | 59 | Poll::Pending 60 | } 61 | } 62 | 63 | impl RpcConnectionServer 64 | where 65 | TConnectionServer: ConnectionService, 66 | { 67 | pub fn new( 68 | connection_server: TConnectionServer, 69 | inbound: mpsc::UnboundedReceiver<::Request>, 70 | outbound: mpsc::Sender<::Response>, 71 | ) -> Self { 72 | Self { 73 | connection_server, 74 | inbound, 75 | outbound: PollSender::new(outbound), 76 | next_messages_buffer: Default::default(), 77 | outstanding_unary_rpcs: Default::default(), 78 | outstanding_streaming_rpcs: Default::default(), 79 | aborts: Default::default(), 80 | } 81 | } 82 | 83 | fn poll_advance_unary_rpcs( 84 | mut self: Pin<&mut Self>, 85 | context: &mut Context<'_>, 86 | ) -> Option>> { 87 | loop { 88 | match pin!(&mut self.outbound).poll_reserve(context) { 89 | Poll::Ready(Ok(())) => { 90 | // ready to send 91 | } 92 | Poll::Ready(Err(_)) => { 93 | log::debug!("outbound connection is closed"); 94 | return Some(Poll::Ready(Err(crate::Error::ConnectionIsClosed))); 95 | } 96 | Poll::Pending => { 97 | log::debug!("no room in outbound connection"); 98 | break; 99 | } 100 | } 101 | 102 | match pin!(&mut self.outstanding_unary_rpcs).poll_next(context) { 103 | Poll::Ready(unary_done) => { 104 | match unary_done { 105 | Some((id, AbortableState::Ready(Ok(response)))) => { 106 | self.aborts.remove(&id); 107 | if let Err(_e) = self.outbound.send_item(response) { 108 | log::debug!("outbound connection is closed"); 109 | return Some(Poll::Ready(Err(crate::Error::ConnectionIsClosed))); 110 | } 111 | } 112 | Some((id, AbortableState::Ready(Err(e)))) => { 113 | let abort = self.aborts.remove(&id); 114 | match e { 
115 | Error::IoFailure(error) => { 116 | log::warn!("{id} io failure while servicing rpc: {error:?}"); 117 | if let Some(abort) = abort { 118 | abort.abort(); 119 | } 120 | } 121 | Error::CancelledRemotely => { 122 | log::debug!("{id} rpc cancelled remotely"); 123 | if let Some(abort) = abort { 124 | abort.abort(); 125 | } 126 | } 127 | Error::ConnectionIsClosed => { 128 | log::debug!("{id} rpc cancelled remotely"); 129 | if let Some(abort) = abort { 130 | abort.abort(); 131 | } 132 | } 133 | Error::Finished => { 134 | log::debug!("{id} unary rpc ended"); 135 | if let Some(abort) = abort { 136 | if let Err(_e) = self.outbound.send_item( 137 | ::ended(id), 138 | ) { 139 | log::debug!("outbound connection is closed"); 140 | return Some(Poll::Ready(Err( 141 | crate::Error::ConnectionIsClosed, 142 | ))); 143 | } 144 | 145 | abort.mark_aborted(); 146 | } 147 | } 148 | } 149 | // cancelled 150 | } 151 | Some((id, AbortableState::Abort)) => { 152 | // This happens when the upstream stuff is dropped and there are no messages that can be produced. We'll send a cancellation. 153 | log::debug!("{id} unary rpc abort"); 154 | if let Some(abort) = self.aborts.remove(&id) { 155 | abort.abort(); 156 | } 157 | } 158 | Some((id, AbortableState::Aborted)) => { 159 | // This happens when the upstream stuff is dropped and there are no messages that can be produced. We'll send a cancellation. 160 | log::debug!("{id} unary rpc done"); 161 | if let Some(abort) = self.aborts.remove(&id) { 162 | abort.mark_aborted(); 163 | } 164 | } 165 | None => { 166 | // nothing to wait for 167 | break; 168 | } 169 | } 170 | } 171 | Poll::Pending => break, 172 | } 173 | } 174 | None 175 | } 176 | 177 | // I want to join this with the above function but it is annoying to zip the SelectAll and FuturesUnordered together. 178 | // This should be possible today with futures::Stream but I need to sit and stare at it for a while to figure out how. 
179 | fn poll_advance_streaming_rpcs( 180 | mut self: Pin<&mut Self>, 181 | context: &mut Context<'_>, 182 | ) -> Option>> { 183 | loop { 184 | match pin!(&mut self.outbound).poll_reserve(context) { 185 | Poll::Ready(Ok(())) => { 186 | // ready to send 187 | } 188 | Poll::Ready(Err(_)) => { 189 | log::debug!("outbound connection is closed"); 190 | return Some(Poll::Ready(Err(crate::Error::ConnectionIsClosed))); 191 | } 192 | Poll::Pending => { 193 | log::debug!("no room in outbound connection"); 194 | break; 195 | } 196 | } 197 | 198 | match pin!(&mut self.outstanding_streaming_rpcs).poll_next(context) { 199 | Poll::Ready(streaming_next) => { 200 | match streaming_next { 201 | Some((id, AbortableState::Ready(Ok(next)))) => { 202 | log::debug!("{id} streaming rpc next {next:?}"); 203 | if let Err(_e) = self.outbound.send_item(next) { 204 | log::debug!("outbound connection is closed"); 205 | return Some(Poll::Ready(Err(crate::Error::ConnectionIsClosed))); 206 | } 207 | } 208 | Some((id, AbortableState::Ready(Err(e)))) => { 209 | let abort = self.aborts.remove(&id); 210 | match e { 211 | Error::IoFailure(error) => { 212 | log::warn!("{id} io failure while servicing rpc: {error:?}"); 213 | if let Some(abort) = abort { 214 | abort.abort(); 215 | } 216 | } 217 | Error::CancelledRemotely => { 218 | log::debug!("{id} rpc cancelled remotely"); 219 | if let Some(abort) = abort { 220 | abort.abort(); 221 | } 222 | } 223 | Error::ConnectionIsClosed => { 224 | log::debug!("{id} rpc cancelled remotely"); 225 | if let Some(abort) = abort { 226 | abort.abort(); 227 | } 228 | } 229 | Error::Finished => { 230 | log::debug!("{id} streaming rpc ended"); 231 | if let Some(abort) = abort { 232 | if let Err(_e) = self.outbound.send_item( 233 | ::ended(id), 234 | ) { 235 | log::debug!("outbound connection is closed"); 236 | return Some(Poll::Ready(Err( 237 | crate::Error::ConnectionIsClosed, 238 | ))); 239 | } 240 | abort.mark_aborted(); 241 | } 242 | } 243 | } 244 | } 245 | Some((id, 
AbortableState::Abort)) => { 246 | // This happens when the upstream stuff is dropped and there are no messages that can be produced. We'll send a cancellation. 247 | log::debug!("{id} streaming rpc abort"); 248 | if let Some(abort) = self.aborts.remove(&id) { 249 | abort.abort(); 250 | } 251 | } 252 | Some((id, AbortableState::Aborted)) => { 253 | log::debug!("{id} streaming rpc done"); 254 | if let Some(abort) = self.aborts.remove(&id) { 255 | abort.mark_aborted(); 256 | } 257 | } 258 | None => { 259 | // nothing to wait for 260 | break; 261 | } 262 | } 263 | } 264 | Poll::Pending => break, 265 | } 266 | } 267 | None 268 | } 269 | 270 | fn poll_receive_buffer( 271 | mut self: Pin<&mut Self>, 272 | context: &mut Context<'_>, 273 | ) -> Option>> { 274 | const MAXIMUM_MESSAGES_PER_POLL: usize = 128; 275 | if self.next_messages_buffer.is_empty() { 276 | let Self { 277 | inbound, 278 | next_messages_buffer, 279 | .. 280 | } = &mut *self; 281 | match inbound.poll_recv_many(context, next_messages_buffer, MAXIMUM_MESSAGES_PER_POLL) { 282 | Poll::Ready(0) => { 283 | return Some(Poll::Ready(Ok(()))); 284 | } 285 | Poll::Ready(_count) => { 286 | // ugh, I know. but poll_recv_many is much cheaper than poll_recv, 287 | // and poll_recv requires &mut Vec. Otherwise this would be a VecDeque with no reverse. 
288 | next_messages_buffer.reverse(); 289 | // possible there is more, but let's just do one batch at a time 290 | context.waker().wake_by_ref(); 291 | } 292 | Poll::Pending => {} 293 | } 294 | } 295 | None 296 | } 297 | 298 | fn handle_message_buffer(mut self: Pin<&mut Self>) { 299 | while let Some(next_message) = self.next_messages_buffer.pop() { 300 | let message_id = next_message.message_id(); 301 | match next_message.control_code() { 302 | ProtosocketControlCode::Normal => { 303 | match self.connection_server.new_rpc(next_message) { 304 | RpcKind::Unary(completion) => { 305 | let (completion, abort) = 306 | IdentifiableAbortable::new(message_id, completion); 307 | self.aborts.insert(message_id, abort); 308 | self.outstanding_unary_rpcs.push(completion); 309 | } 310 | RpcKind::Streaming(completion) => { 311 | let (completion, abort) = 312 | IdentifiableAbortable::new(message_id, completion); 313 | self.aborts.insert(message_id, abort); 314 | self.outstanding_streaming_rpcs.push(completion); 315 | } 316 | RpcKind::Unknown => { 317 | log::debug!("skipping message {message_id}"); 318 | } 319 | } 320 | } 321 | ProtosocketControlCode::Cancel => { 322 | if let Some(abort) = self.aborts.remove(&message_id) { 323 | log::debug!("cancelling message {message_id}"); 324 | abort.mark_aborted(); 325 | } else { 326 | log::debug!("received cancellation for untracked message {message_id}"); 327 | } 328 | } 329 | ProtosocketControlCode::End => { 330 | log::debug!("received end message {message_id}"); 331 | } 332 | } 333 | } 334 | } 335 | } 336 | 337 | #[cfg(test)] 338 | mod test { 339 | use std::{ 340 | future::Future, 341 | pin::pin, 342 | ptr, 343 | task::{Context, Poll, RawWaker, RawWakerVTable, Waker}, 344 | }; 345 | 346 | use futures::{FutureExt, StreamExt}; 347 | use tokio::sync::mpsc; 348 | 349 | use crate::{ 350 | server::{ConnectionService, RpcKind}, 351 | ProtosocketControlCode, 352 | }; 353 | 354 | use super::RpcConnectionServer; 355 | 356 | #[derive(Clone, PartialEq, 
Eq, prost::Message, PartialOrd, Ord)] 357 | pub struct Message { 358 | #[prost(uint64, tag = "1")] 359 | pub id: u64, 360 | #[prost(uint32, tag = "2")] 361 | pub code: u32, 362 | #[prost(uint64, tag = "3")] 363 | pub n: u64, 364 | } 365 | 366 | impl crate::Message for Message { 367 | fn message_id(&self) -> u64 { 368 | self.id 369 | } 370 | 371 | fn control_code(&self) -> crate::ProtosocketControlCode { 372 | crate::ProtosocketControlCode::from_u8(self.code as u8) 373 | } 374 | 375 | fn set_message_id(&mut self, message_id: u64) { 376 | self.id = message_id; 377 | } 378 | 379 | fn cancelled(message_id: u64) -> Self { 380 | Self { 381 | id: message_id, 382 | n: 0, 383 | code: ProtosocketControlCode::Cancel.as_u8() as u32, 384 | } 385 | } 386 | 387 | fn ended(message_id: u64) -> Self { 388 | Self { 389 | id: message_id, 390 | n: 0, 391 | code: ProtosocketControlCode::End.as_u8() as u32, 392 | } 393 | } 394 | } 395 | 396 | const HANGING_UNARY_MESSAGE: u64 = 2000; 397 | const HANGING_STREAMING_MESSAGE: u64 = 3000; 398 | struct TestConnectionService; 399 | impl std::fmt::Debug for TestConnectionService { 400 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 401 | f.debug_struct("TestConnectionService").finish() 402 | } 403 | } 404 | 405 | impl ConnectionService for TestConnectionService { 406 | type Request = Message; 407 | type Response = Message; 408 | // Boxing is used for convenience in tests. You should try to use a static type in your real code. 
409 | type UnaryFutureType = futures::future::BoxFuture<'static, Message>; 410 | type StreamType = futures::stream::BoxStream<'static, Message>; 411 | 412 | fn new_rpc( 413 | &mut self, 414 | request: Self::Request, 415 | ) -> crate::server::RpcKind { 416 | if request.id == HANGING_UNARY_MESSAGE { 417 | RpcKind::Unary(futures::future::pending().boxed()) 418 | } else if request.id == HANGING_STREAMING_MESSAGE { 419 | RpcKind::Streaming(futures::stream::pending().boxed()) 420 | } else if request.id < 1000 { 421 | RpcKind::Unary( 422 | futures::future::ready(Message { 423 | id: request.id, 424 | code: ProtosocketControlCode::Normal.as_u8() as u32, 425 | n: request.n + 1, 426 | }) 427 | .boxed(), 428 | ) 429 | } else { 430 | RpcKind::Streaming( 431 | futures::stream::iter((0..request.n).map(move |n| Message { 432 | id: request.id, 433 | code: ProtosocketControlCode::Normal.as_u8() as u32, 434 | n, 435 | })) 436 | .boxed(), 437 | ) 438 | } 439 | } 440 | } 441 | 442 | pub fn noop_waker() -> Waker { 443 | const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new( 444 | |_| RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE), 445 | |_| {}, 446 | |_| {}, 447 | |_| {}, 448 | ); 449 | let raw = RawWaker::new(ptr::null(), &NOOP_WAKER_VTABLE); 450 | // SAFETY: the contracts for RawWaker and RawWakerVTable are trivially upheld by always making new wakers 451 | unsafe { Waker::from_raw(raw) } 452 | } 453 | 454 | fn test_server( 455 | outbound_buffer: usize, 456 | ) -> ( 457 | mpsc::UnboundedSender, 458 | mpsc::Receiver, 459 | RpcConnectionServer, 460 | ) { 461 | let (inbound_sender, inbound) = mpsc::unbounded_channel(); 462 | let (outbound, outbound_receiver) = mpsc::channel(outbound_buffer); 463 | let server = RpcConnectionServer::new(TestConnectionService, inbound, outbound); 464 | (inbound_sender, outbound_receiver, server) 465 | } 466 | 467 | #[track_caller] 468 | fn assert_next( 469 | message: Message, 470 | outbound_receiver: &mut mpsc::Receiver, 471 | context: &mut 
Context<'_>, 472 | ) { 473 | assert_eq!( 474 | Poll::Ready(Some(message)), 475 | outbound_receiver.poll_recv(context) 476 | ); 477 | } 478 | 479 | #[track_caller] 480 | fn poll_next( 481 | outbound_receiver: &mut mpsc::Receiver, 482 | context: &mut Context<'_>, 483 | ) -> Message { 484 | match outbound_receiver.poll_recv(context) { 485 | Poll::Ready(Some(message)) => message, 486 | got => panic!("expected message, got {got:?}"), 487 | } 488 | } 489 | 490 | #[test] 491 | fn unary() { 492 | let waker = noop_waker(); 493 | let mut context = Context::from_waker(&waker); 494 | 495 | let (inbound_sender, mut outbound_receiver, mut server) = test_server(3); 496 | 497 | // test messages below 1000 are unary. Response is n + 1 498 | let _ = inbound_sender.send(Message { 499 | id: 1, 500 | code: 0, 501 | n: 1, 502 | }); 503 | 504 | assert_eq!( 505 | Poll::Pending, 506 | outbound_receiver.poll_recv(&mut context), 507 | "nothing should be sent until the server advances to accept the message" 508 | ); 509 | 510 | assert!( 511 | pin!(&mut server).poll(&mut context).is_pending(), 512 | "server should be pending forever" 513 | ); 514 | assert_eq!( 515 | 0, 516 | server.outstanding_unary_rpcs.len(), 517 | "it completed in one poll" 518 | ); 519 | 520 | assert_next( 521 | Message { 522 | id: 1, 523 | code: 0, 524 | n: 2, 525 | }, 526 | &mut outbound_receiver, 527 | &mut context, 528 | ); 529 | } 530 | 531 | #[test] 532 | fn concurrent_unary() { 533 | let waker = noop_waker(); 534 | let mut context = Context::from_waker(&waker); 535 | 536 | let (inbound_sender, mut outbound_receiver, mut server) = test_server(3); 537 | 538 | let _ = inbound_sender.send(Message { 539 | id: 1, 540 | code: 0, 541 | n: 1, 542 | }); 543 | let _ = inbound_sender.send(Message { 544 | id: 2, 545 | code: 0, 546 | n: 3, 547 | }); 548 | let _ = inbound_sender.send(Message { 549 | id: 3, 550 | code: 0, 551 | n: 5, 552 | }); 553 | 554 | // the server takes up to MAXIMUM_MESSAGES_PER_POLL per poll. 
I only submitted 3, so they should 555 | // all get processed in the a single round of poll. 556 | assert!( 557 | pin!(&mut server).poll(&mut context).is_pending(), 558 | "server should be pending forever" 559 | ); 560 | assert_eq!( 561 | 0, 562 | server.outstanding_unary_rpcs.len(), 563 | "it completed in one poll" 564 | ); 565 | 566 | let mut concurrent_completions = vec![ 567 | poll_next(&mut outbound_receiver, &mut context), 568 | poll_next(&mut outbound_receiver, &mut context), 569 | poll_next(&mut outbound_receiver, &mut context), 570 | ]; 571 | // they are allowed to complete in any order but I'd like a deterministic order for the assertion 572 | concurrent_completions.sort(); 573 | 574 | assert_eq!( 575 | vec![ 576 | Message { 577 | id: 1, 578 | code: 0, 579 | n: 2 580 | }, 581 | Message { 582 | id: 2, 583 | code: 0, 584 | n: 4 585 | }, 586 | Message { 587 | id: 3, 588 | code: 0, 589 | n: 6 590 | }, 591 | ], 592 | concurrent_completions, 593 | ); 594 | assert_eq!( 595 | Poll::Pending, 596 | outbound_receiver.poll_recv(&mut context), 597 | "no made up messages" 598 | ); 599 | } 600 | 601 | #[test] 602 | fn streaming() { 603 | let waker = noop_waker(); 604 | let mut context = Context::from_waker(&waker); 605 | 606 | let (inbound_sender, mut outbound_receiver, mut server) = test_server(3); 607 | // "test" messages at and above 1000 are streaming. 
Stream has responses n=0..n 608 | let _ = inbound_sender.send(Message { 609 | id: 1000, 610 | code: 0, 611 | n: 2, 612 | }); 613 | assert!( 614 | pin!(&mut server).poll(&mut context).is_pending(), 615 | "server should be pending forever" 616 | ); 617 | 618 | let first_message = poll_next(&mut outbound_receiver, &mut context); 619 | assert_eq!( 620 | 1, 621 | server.outstanding_streaming_rpcs.len(), 622 | "there should still be an outstanding rpc because the stream is not done" 623 | ); 624 | let messages = vec![ 625 | first_message, 626 | poll_next(&mut outbound_receiver, &mut context), 627 | poll_next(&mut outbound_receiver, &mut context), 628 | ]; 629 | // these must come in the correct order. 630 | 631 | assert_eq!( 632 | vec![ 633 | Message { 634 | id: 1000, 635 | code: 0, 636 | n: 0 637 | }, 638 | Message { 639 | id: 1000, 640 | code: 0, 641 | n: 1 642 | }, 643 | Message { 644 | id: 1000, 645 | code: ProtosocketControlCode::End.as_u8() as u32, 646 | n: 0 647 | }, 648 | ], 649 | messages, 650 | ); 651 | 652 | assert_eq!(1, server.outstanding_streaming_rpcs.len(), "server has not yet discovered that this rpc is complete. This might change if the poll batch process is changed"); 653 | assert!( 654 | pin!(&mut server).poll(&mut context).is_pending(), 655 | "server should be pending forever" 656 | ); 657 | assert_eq!( 658 | 0, 659 | server.outstanding_streaming_rpcs.len(), 660 | "all rpcs should be completed" 661 | ); 662 | assert_eq!( 663 | Poll::Pending, 664 | outbound_receiver.poll_recv(&mut context), 665 | "no made up messages" 666 | ); 667 | } 668 | 669 | #[test] 670 | fn streaming_concurrent() { 671 | let waker = noop_waker(); 672 | let mut context = Context::from_waker(&waker); 673 | 674 | let (inbound_sender, mut outbound_receiver, mut server) = test_server(3); 675 | // "test" messages at and above 1000 are streaming. 
Stream has responses n=0..n 676 | let _ = inbound_sender.send(Message { 677 | id: 1000, 678 | code: 0, 679 | n: 2, 680 | }); 681 | let _ = inbound_sender.send(Message { 682 | id: 1001, 683 | code: 0, 684 | n: 2, 685 | }); 686 | let _ = inbound_sender.send(Message { 687 | id: 1002, 688 | code: 0, 689 | n: 2, 690 | }); 691 | 692 | assert!( 693 | pin!(&mut server).poll(&mut context).is_pending(), 694 | "server should be pending forever" 695 | ); 696 | assert_eq!(3, server.outstanding_streaming_rpcs.len()); 697 | 698 | let mut messages = vec![ 699 | poll_next(&mut outbound_receiver, &mut context), 700 | poll_next(&mut outbound_receiver, &mut context), 701 | poll_next(&mut outbound_receiver, &mut context), 702 | ]; 703 | assert_eq!( 704 | Poll::Pending, 705 | outbound_receiver.poll_recv(&mut context), 706 | "outbound buffer is only 3. It is unknown if any of the rpcs are complete" 707 | ); 708 | assert!( 709 | pin!(&mut server).poll(&mut context).is_pending(), 710 | "server should be pending forever" 711 | ); 712 | messages.push(poll_next(&mut outbound_receiver, &mut context)); 713 | messages.push(poll_next(&mut outbound_receiver, &mut context)); 714 | messages.push(poll_next(&mut outbound_receiver, &mut context)); 715 | assert_eq!(Poll::Pending, outbound_receiver.poll_recv(&mut context), "though we only defined 6 messages, the server sends an End message for each gracefully ended stream"); 716 | assert!( 717 | pin!(&mut server).poll(&mut context).is_pending(), 718 | "server should be pending forever" 719 | ); 720 | messages.push(poll_next(&mut outbound_receiver, &mut context)); 721 | messages.push(poll_next(&mut outbound_receiver, &mut context)); 722 | messages.push(poll_next(&mut outbound_receiver, &mut context)); 723 | 724 | // The messages may be intermixed per-rpc, but they must be mutually in order per-rpc. 725 | // It is a weak assertion to sort these, because that would allow _reordered streams_ to pass the test. 
726 | let first_rpc: Vec<_> = messages 727 | .iter() 728 | .filter(|message| message.id == 1000) 729 | .cloned() 730 | .collect(); 731 | let second_rpc: Vec<_> = messages 732 | .iter() 733 | .filter(|message| message.id == 1001) 734 | .cloned() 735 | .collect(); 736 | let third_rpc: Vec<_> = messages 737 | .iter() 738 | .filter(|message| message.id == 1002) 739 | .cloned() 740 | .collect(); 741 | 742 | assert_eq!( 743 | vec![ 744 | Message { 745 | id: 1000, 746 | code: 0, 747 | n: 0 748 | }, 749 | Message { 750 | id: 1000, 751 | code: 0, 752 | n: 1 753 | }, 754 | Message { 755 | id: 1000, 756 | code: ProtosocketControlCode::End.as_u8() as u32, 757 | n: 0 758 | }, 759 | ], 760 | first_rpc, 761 | ); 762 | assert_eq!( 763 | vec![ 764 | Message { 765 | id: 1001, 766 | code: 0, 767 | n: 0 768 | }, 769 | Message { 770 | id: 1001, 771 | code: 0, 772 | n: 1 773 | }, 774 | Message { 775 | id: 1001, 776 | code: ProtosocketControlCode::End.as_u8() as u32, 777 | n: 0 778 | }, 779 | ], 780 | second_rpc, 781 | ); 782 | assert_eq!( 783 | vec![ 784 | Message { 785 | id: 1002, 786 | code: 0, 787 | n: 0 788 | }, 789 | Message { 790 | id: 1002, 791 | code: 0, 792 | n: 1 793 | }, 794 | Message { 795 | id: 1002, 796 | code: ProtosocketControlCode::End.as_u8() as u32, 797 | n: 0 798 | }, 799 | ], 800 | third_rpc, 801 | ); 802 | // server may have 0-3 pending rpcs, but they should all complete with the next poll. 803 | assert!( 804 | pin!(&mut server).poll(&mut context).is_pending(), 805 | "server should be pending forever" 806 | ); 807 | assert_eq!( 808 | 0, 809 | server.outstanding_streaming_rpcs.len(), 810 | "all rpcs should be completed" 811 | ); 812 | assert_eq!( 813 | Poll::Pending, 814 | outbound_receiver.poll_recv(&mut context), 815 | "no made up messages" 816 | ); 817 | } 818 | 819 | // This test makes sure that the server drops a unary rpc when it asked to do so. 
820 | #[test] 821 | fn unary_client_cancellation() { 822 | let waker = noop_waker(); 823 | let mut context = Context::from_waker(&waker); 824 | 825 | let (inbound_sender, mut outbound_receiver, mut server) = test_server(3); 826 | 827 | let _ = inbound_sender.send(Message { 828 | id: HANGING_UNARY_MESSAGE, 829 | code: 0, 830 | n: 1, 831 | }); 832 | assert!(pin!(&mut server).poll(&mut context).is_pending()); 833 | 834 | assert_eq!( 835 | 1, 836 | server.outstanding_unary_rpcs.len(), 837 | "it will never complete" 838 | ); 839 | 840 | let _ = inbound_sender.send(Message { 841 | id: HANGING_UNARY_MESSAGE, 842 | code: ProtosocketControlCode::Cancel.as_u8() as u32, 843 | n: 0, 844 | }); 845 | 846 | assert!( 847 | pin!(&mut server).poll(&mut context).is_pending(), 848 | "server should be pending forever" 849 | ); 850 | assert_eq!( 851 | 0, 852 | server.outstanding_unary_rpcs.len(), 853 | "all rpcs should be completed" 854 | ); 855 | assert_eq!( 856 | Poll::Pending, 857 | outbound_receiver.poll_recv(&mut context), 858 | "no made up messages" 859 | ); 860 | } 861 | 862 | // This test makes sure that the server drops a streaming rpc when it asked to do so. 
863 | #[test] 864 | fn streaming_client_cancellation() { 865 | let waker = noop_waker(); 866 | let mut context = Context::from_waker(&waker); 867 | 868 | let (inbound_sender, mut outbound_receiver, mut server) = test_server(3); 869 | 870 | let _ = inbound_sender.send(Message { 871 | id: HANGING_STREAMING_MESSAGE, 872 | code: 0, 873 | n: 1, 874 | }); 875 | assert!(pin!(&mut server).poll(&mut context).is_pending()); 876 | 877 | assert_eq!( 878 | 1, 879 | server.outstanding_streaming_rpcs.len(), 880 | "it will never complete" 881 | ); 882 | 883 | let _ = inbound_sender.send(Message { 884 | id: HANGING_STREAMING_MESSAGE, 885 | code: ProtosocketControlCode::Cancel.as_u8() as u32, 886 | n: 0, 887 | }); 888 | 889 | assert!( 890 | pin!(&mut server).poll(&mut context).is_pending(), 891 | "server should be pending forever" 892 | ); 893 | assert_eq!( 894 | 0, 895 | server.outstanding_streaming_rpcs.len(), 896 | "all rpcs should be completed" 897 | ); 898 | assert_eq!( 899 | Poll::Pending, 900 | outbound_receiver.poll_recv(&mut context), 901 | "no made up messages" 902 | ); 903 | } 904 | } 905 | -------------------------------------------------------------------------------- /protosocket-rpc/src/server/mod.rs: -------------------------------------------------------------------------------- 1 | mod abortable; 2 | mod connection_server; 3 | // mod queue_reactor; 4 | mod rpc_submitter; 5 | mod server_traits; 6 | mod socket_server; 7 | 8 | pub use server_traits::{ConnectionService, RpcKind, SocketService}; 9 | pub use socket_server::SocketRpcServer; 10 | -------------------------------------------------------------------------------- /protosocket-rpc/src/server/rpc_submitter.rs: -------------------------------------------------------------------------------- 1 | use protosocket::{ConnectionBindings, Deserializer, MessageReactor}; 2 | use tokio::sync::mpsc; 3 | 4 | use super::SocketService; 5 | 6 | #[derive(Debug, Clone)] 7 | pub struct RpcSubmitter 8 | where 9 | TSocketService: 
SocketService, 10 | { 11 | sender: mpsc::UnboundedSender<::Message>, 12 | } 13 | impl RpcSubmitter 14 | where 15 | TSocketService: SocketService, 16 | { 17 | pub fn new() -> ( 18 | Self, 19 | mpsc::UnboundedReceiver<::Message>, 20 | ) { 21 | let (sender, receiver) = mpsc::unbounded_channel(); 22 | (Self { sender }, receiver) 23 | } 24 | } 25 | 26 | impl MessageReactor for RpcSubmitter 27 | where 28 | TSocketService: SocketService, 29 | { 30 | type Inbound = ::Message; 31 | 32 | fn on_inbound_messages( 33 | &mut self, 34 | messages: impl IntoIterator, 35 | ) -> protosocket::ReactorStatus { 36 | for message in messages.into_iter() { 37 | match self.sender.send(message) { 38 | Ok(_) => (), 39 | Err(e) => { 40 | log::warn!("failed to send message: {:?}", e); 41 | return protosocket::ReactorStatus::Disconnect; 42 | } 43 | } 44 | } 45 | protosocket::ReactorStatus::Continue 46 | } 47 | } 48 | 49 | impl ConnectionBindings for RpcSubmitter 50 | where 51 | TSocketService: SocketService, 52 | { 53 | type Deserializer = TSocketService::RequestDeserializer; 54 | type Serializer = TSocketService::ResponseSerializer; 55 | type Reactor = RpcSubmitter; 56 | } 57 | -------------------------------------------------------------------------------- /protosocket-rpc/src/server/server_traits.rs: -------------------------------------------------------------------------------- 1 | use std::{future::Future, net::SocketAddr}; 2 | 3 | use protosocket::{Deserializer, Serializer}; 4 | 5 | use crate::Message; 6 | 7 | /// SocketService receives connections and produces ConnectionServices. 8 | /// 9 | /// The SocketService is notified when a new connection is established. It is given the address of the 10 | /// remote peer and it returns a ConnectionService for that connection. You can think of this as the 11 | /// "connection factory" for your server. It is the "top" of your service stack. 12 | pub trait SocketService: 'static { 13 | /// The type of deserializer for incoming messages. 
14 | type RequestDeserializer: Deserializer + 'static; 15 | /// The type of serializer for outgoing messages. 16 | type ResponseSerializer: Serializer + 'static; 17 | /// The type of connection service that will be created for each connection. 18 | type ConnectionService: ConnectionService< 19 | Request = ::Message, 20 | Response = ::Message, 21 | >; 22 | 23 | /// Create a new deserializer for incoming messages. 24 | fn deserializer(&self) -> Self::RequestDeserializer; 25 | /// Create a new serializer for outgoing messages. 26 | fn serializer(&self) -> Self::ResponseSerializer; 27 | 28 | /// Create a new ConnectionService for a new connection. 29 | fn new_connection_service(&self, address: SocketAddr) -> Self::ConnectionService; 30 | } 31 | 32 | /// A connection service receives rpcs from clients and sends responses. 33 | /// 34 | /// Each client connection gets a ConnectionService. You put your per-connection state in your 35 | /// ConnectionService implementation. 36 | /// 37 | /// Every interaction with a client is done via an RPC. You are called with the initiating message 38 | /// from the client, and you return the kind of response future that is used to complete the RPC. 39 | /// 40 | /// A ConnectionService is executed in the context of an RPC connection server, which is a future. 41 | /// This means you get `&mut self` when you are called with a new rpc. You can use simple mutable 42 | /// state per-connection; but if you need to share state between connections or elsewhere in your 43 | /// application, you will need to use an appropriate state sharing mechanism. 44 | pub trait ConnectionService: Send + Unpin + 'static { 45 | /// The type of request message, These messages initiate rpcs. 46 | type Request: Message; 47 | /// The type of response message, These messages complete rpcs, or are streamed from them. 48 | type Response: Message; 49 | /// The type of future that completes a unary rpc. 
50 | type UnaryFutureType: Future + Send + Unpin; 51 | /// The type of stream that completes a streaming rpc. 52 | type StreamType: futures::Stream + Send + Unpin; 53 | 54 | /// Create a new rpc task completion. 55 | /// 56 | /// You can provide a concrete Future and it will be polled in the context of the Connection 57 | /// itself. This would limit your Connection and all of its outstanding rpc's to 1 cpu at a time. 58 | /// That might be good for your use case, or it might be suboptimal. 59 | /// You can of course also spawn a task and return a completion future that completes when the 60 | /// task completes, e.g., with a tokio::sync::oneshot or mpsc stream. In general, try to do as 61 | /// little as possible: Return a future (rather than a task handle) and let the ConnectionServer 62 | /// task poll it. This keeps your task count low and your wakes more tightly related to the 63 | /// cooperating tasks (e.g., ConnectionServer and Connection) that need to be woken. 64 | fn new_rpc( 65 | &mut self, 66 | initiating_message: Self::Request, 67 | ) -> RpcKind; 68 | } 69 | 70 | /// Type of rpc to be awaited 71 | pub enum RpcKind { 72 | /// This is a unary rpc. It will complete with a single response. 73 | Unary(Unary), 74 | /// This is a streaming rpc. It will complete with a stream of responses. 75 | Streaming(Streaming), 76 | /// This is an unknown rpc. It will be skipped. 
77 | Unknown, 78 | } 79 | -------------------------------------------------------------------------------- /protosocket-rpc/src/server/socket_server.rs: -------------------------------------------------------------------------------- 1 | use std::future::Future; 2 | use std::io::Error; 3 | use std::pin::Pin; 4 | use std::task::Context; 5 | use std::task::Poll; 6 | 7 | use protosocket::Connection; 8 | use tokio::sync::mpsc; 9 | 10 | use super::connection_server::RpcConnectionServer; 11 | use super::rpc_submitter::RpcSubmitter; 12 | use super::server_traits::SocketService; 13 | 14 | /// A `SocketRpcServer` is a server future. It listens on a socket and spawns new connections, 15 | /// with a ConnectionService to handle each connection. 16 | /// 17 | /// Protosockets use monomorphic messages: You can only have 1 kind of message per service. 18 | /// The expected way to work with this is to use prost and protocol buffers to encode messages. 19 | /// 20 | /// The socket server hosts your SocketService. 21 | /// Your SocketService creates a ConnectionService for each new connection. 22 | /// Your ConnectionService manages one connection. It is Dropped when the connection is closed. 23 | pub struct SocketRpcServer 24 | where 25 | TSocketService: SocketService, 26 | { 27 | socket_server: TSocketService, 28 | listener: tokio::net::TcpListener, 29 | max_buffer_length: usize, 30 | max_queued_outbound_messages: usize, 31 | } 32 | 33 | impl SocketRpcServer 34 | where 35 | TSocketService: SocketService, 36 | { 37 | /// Construct a new `SocketRpcServer` listening on the provided address. 
38 | pub async fn new( 39 | address: std::net::SocketAddr, 40 | socket_server: TSocketService, 41 | ) -> crate::Result { 42 | let listener = tokio::net::TcpListener::bind(address).await?; 43 | Ok(Self { 44 | socket_server, 45 | listener, 46 | max_buffer_length: 16 * (2 << 20), 47 | max_queued_outbound_messages: 128, 48 | }) 49 | } 50 | 51 | /// Set the maximum buffer length for connections created by this server after the setting is applied. 52 | pub fn set_max_buffer_length(&mut self, max_buffer_length: usize) { 53 | self.max_buffer_length = max_buffer_length; 54 | } 55 | 56 | /// Set the maximum queued outbound messages for connections created by this server after the setting is applied. 57 | pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) { 58 | self.max_queued_outbound_messages = max_queued_outbound_messages; 59 | } 60 | } 61 | 62 | impl Future for SocketRpcServer 63 | where 64 | TSocketService: SocketService, 65 | { 66 | type Output = Result<(), Error>; 67 | 68 | fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll { 69 | loop { 70 | break match self.listener.poll_accept(context) { 71 | Poll::Ready(result) => match result { 72 | Ok((stream, address)) => { 73 | stream.set_nodelay(true)?; 74 | let (submitter, inbound_messages) = RpcSubmitter::new(); 75 | let (outbound_messages, outbound_messages_receiver) = 76 | mpsc::channel(self.max_queued_outbound_messages); 77 | let connection_service = self.socket_server.new_connection_service(address); 78 | let connection_rpc_server = RpcConnectionServer::new( 79 | connection_service, 80 | inbound_messages, 81 | outbound_messages, 82 | ); 83 | 84 | let connection: Connection> = Connection::new( 85 | stream, 86 | address, 87 | self.socket_server.deserializer(), 88 | self.socket_server.serializer(), 89 | self.max_buffer_length, 90 | self.max_queued_outbound_messages, 91 | outbound_messages_receiver, 92 | submitter, 93 | ); 94 | 95 | tokio::spawn(connection); 96 | 
tokio::spawn(connection_rpc_server); 97 | 98 | continue; 99 | } 100 | Err(e) => { 101 | log::error!("failed to accept connection: {e:?}"); 102 | continue; 103 | } 104 | }, 105 | Poll::Pending => Poll::Pending, 106 | }; 107 | } 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /protosocket-server/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "protosocket-server" 3 | description = "Message-oriented nonblocking tcp stream - server tools" 4 | version.workspace = true 5 | edition.workspace = true 6 | license.workspace = true 7 | authors.workspace = true 8 | readme.workspace = true 9 | repository.workspace = true 10 | keywords.workspace = true 11 | categories.workspace = true 12 | 13 | [dependencies] 14 | protosocket = { workspace = true } 15 | 16 | bytes = { workspace = true } 17 | futures = { workspace = true } 18 | log = { workspace = true } 19 | thiserror = { workspace = true } 20 | tokio = { workspace = true } 21 | -------------------------------------------------------------------------------- /protosocket-server/src/connection_server.rs: -------------------------------------------------------------------------------- 1 | use std::future::Future; 2 | use std::io::Error; 3 | use std::pin::Pin; 4 | use std::sync::Arc; 5 | use std::task::Context; 6 | use std::task::Poll; 7 | 8 | use protosocket::Connection; 9 | use protosocket::ConnectionBindings; 10 | use protosocket::Serializer; 11 | use tokio::sync::mpsc; 12 | 13 | pub trait ServerConnector: Unpin { 14 | type Bindings: ConnectionBindings; 15 | 16 | fn serializer(&self) -> ::Serializer; 17 | fn deserializer(&self) -> ::Deserializer; 18 | 19 | fn new_reactor( 20 | &self, 21 | optional_outbound: mpsc::Sender< 22 | <::Serializer as Serializer>::Message, 23 | >, 24 | ) -> ::Reactor; 25 | 26 | fn maximum_message_length(&self) -> usize { 27 | 4 * (2 << 20) 28 | } 29 | 30 | fn 
max_queued_outbound_messages(&self) -> usize { 31 | 256 32 | } 33 | } 34 | 35 | /// A `protosocket::Connection` is an IO driver. It directly uses tokio's io wrapper of mio to poll 36 | /// the OS's io primitives, manages read and write buffers, and vends messages to & from connections. 37 | /// Connections send messages to the ConnectionServer through an mpsc channel, and they receive 38 | /// inbound messages via a reactor callback. 39 | /// 40 | /// Protosockets are monomorphic messages: You can only have 1 kind of message per service. 41 | /// The expected way to work with this is to use prost and protocol buffers to encode messages. 42 | /// Of course you can do whatever you want, as the telnet example shows. 43 | /// 44 | /// Protosocket messages are not opinionated about request & reply. If you are, you will need 45 | /// to implement such a thing. This allows you freely choose whether you want to send 46 | /// fire-&-forget messages sometimes; however it requires you to write your protocol's rules. 47 | /// You get an inbound iterable of batches and an outbound stream of per 48 | /// connection - you decide what those mean for you! 49 | /// 50 | /// A ProtosocketServer is a future: You spawn it and it runs forever. 51 | pub struct ProtosocketServer { 52 | connector: Connector, 53 | listener: tokio::net::TcpListener, 54 | max_buffer_length: usize, 55 | max_queued_outbound_messages: usize, 56 | runtime: tokio::runtime::Handle, 57 | } 58 | 59 | impl ProtosocketServer { 60 | /// Construct a new `ProtosocketServer` listening on the provided address. 61 | /// The address will be bound and listened upon with `SO_REUSEADDR` set. 62 | /// The server will use the provided runtime to spawn new tcp connections as `protosocket::Connection`s. 
63 | pub async fn new( 64 | address: std::net::SocketAddr, 65 | runtime: tokio::runtime::Handle, 66 | connector: Connector, 67 | ) -> crate::Result { 68 | let listener = tokio::net::TcpListener::bind(address) 69 | .await 70 | .map_err(Arc::new)?; 71 | Ok(Self { 72 | connector, 73 | listener, 74 | max_buffer_length: 16 * (2 << 20), 75 | max_queued_outbound_messages: 128, 76 | runtime, 77 | }) 78 | } 79 | 80 | /// Set the maximum buffer length for connections created by this server after the setting is applied. 81 | pub fn set_max_buffer_length(&mut self, max_buffer_length: usize) { 82 | self.max_buffer_length = max_buffer_length; 83 | } 84 | 85 | /// Set the maximum queued outbound messages for connections created by this server after the setting is applied. 86 | pub fn set_max_queued_outbound_messages(&mut self, max_queued_outbound_messages: usize) { 87 | self.max_queued_outbound_messages = max_queued_outbound_messages; 88 | } 89 | } 90 | 91 | impl Future for ProtosocketServer { 92 | type Output = Result<(), Error>; 93 | 94 | fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll { 95 | loop { 96 | break match self.listener.poll_accept(context) { 97 | Poll::Ready(result) => match result { 98 | Ok((stream, address)) => { 99 | stream.set_nodelay(true)?; 100 | let (outbound_submission_queue, outbound_messages) = 101 | mpsc::channel(self.max_queued_outbound_messages); 102 | let reactor = self 103 | .connector 104 | .new_reactor(outbound_submission_queue.clone()); 105 | let connection: Connection = Connection::new( 106 | stream, 107 | address, 108 | self.connector.deserializer(), 109 | self.connector.serializer(), 110 | self.max_buffer_length, 111 | self.max_queued_outbound_messages, 112 | outbound_messages, 113 | reactor, 114 | ); 115 | self.runtime.spawn(connection); 116 | continue; 117 | } 118 | Err(e) => { 119 | log::error!("failed to accept connection: {e:?}"); 120 | continue; 121 | } 122 | }, 123 | Poll::Pending => Poll::Pending, 124 | }; 125 | } 126 | } 
127 | } 128 | -------------------------------------------------------------------------------- /protosocket-server/src/error.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | /// Result type for protosocket-server. 4 | pub type Result = std::result::Result; 5 | 6 | /// Error type for protosocket-server. 7 | #[derive(Clone, Debug, thiserror::Error)] 8 | pub enum Error { 9 | #[error("IO failure: {0}")] 10 | IoFailure(#[from] Arc), 11 | #[error("Bad address: {0}")] 12 | AddressError(#[from] core::net::AddrParseError), 13 | #[error("Requested resource was dead: ({0})")] 14 | Dead(&'static str), 15 | } 16 | -------------------------------------------------------------------------------- /protosocket-server/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Conveniences for writing protosocket servers. 2 | //! 3 | //! See example-telnet for the simplest full example of the entire workings, 4 | //! or example-proto for an example of how to use this crate with protocol buffers. 5 | 6 | pub(crate) mod connection_server; 7 | pub(crate) mod error; 8 | 9 | pub use connection_server::ProtosocketServer; 10 | pub use connection_server::ServerConnector; 11 | pub use error::Error; 12 | pub use error::Result; 13 | --------------------------------------------------------------------------------