├── .gitignore ├── tests ├── compile_fail │ ├── non_enum.rs │ ├── multiple_fields.rs │ ├── named_enum.rs │ ├── duplicate_type.rs │ ├── non_enum.stderr │ ├── wrong_attr_types.rs │ ├── wrong_attr_types.stderr │ ├── multiple_fields.stderr │ ├── named_enum.stderr │ ├── extra_attr_types.rs │ ├── duplicate_type.stderr │ └── extra_attr_types.stderr ├── common.rs ├── derive.rs ├── oneshot_channel.rs └── mpsc_channel.rs ├── .cargo └── config.toml ├── LICENSE-APACHE ├── irpc-derive ├── Cargo.toml └── src │ └── lib.rs ├── src ├── tests.rs └── util.rs ├── DOCS.md ├── LICENSE-MIT ├── irpc-iroh ├── Cargo.toml ├── examples │ ├── simple.rs │ ├── derive.rs │ ├── auth.rs │ └── 0rtt.rs └── src │ └── lib.rs ├── README.md ├── examples ├── local.rs ├── derive.rs ├── storage.rs └── compute.rs ├── Cargo.toml └── .github └── workflows └── ci.yml /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | .vscode/* 3 | -------------------------------------------------------------------------------- /tests/compile_fail/non_enum.rs: -------------------------------------------------------------------------------- 1 | use irpc::rpc_requests; 2 | 3 | #[rpc_requests(Service, Msg)] 4 | struct Foo; 5 | 6 | fn main() {} -------------------------------------------------------------------------------- /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [target.wasm32-unknown-unknown] 2 | runner = "wasm-bindgen-test-runner" 3 | rustflags = ['--cfg', 'getrandom_backend="wasm_js"'] 4 | -------------------------------------------------------------------------------- /tests/compile_fail/multiple_fields.rs: -------------------------------------------------------------------------------- 1 | use irpc::rpc_requests; 2 | 3 | #[rpc_requests(Service, Msg)] 4 | enum Enum { 5 | A(u8, u8), 6 | } 7 | 8 | fn main() {} -------------------------------------------------------------------------------- 
/tests/compile_fail/named_enum.rs: -------------------------------------------------------------------------------- 1 | use irpc::rpc_requests; 2 | 3 | #[rpc_requests(Service, Msg)] 4 | enum Enum { 5 | A { name: u8 }, 6 | } 7 | 8 | fn main() {} -------------------------------------------------------------------------------- /tests/compile_fail/duplicate_type.rs: -------------------------------------------------------------------------------- 1 | use irpc::rpc_requests; 2 | 3 | #[rpc_requests(Service, Msg)] 4 | enum Enum { 5 | A(u8), 6 | B(u8), 7 | } 8 | 9 | fn main() {} -------------------------------------------------------------------------------- /tests/compile_fail/non_enum.stderr: -------------------------------------------------------------------------------- 1 | error: RpcRequests can only be applied to enums 2 | --> tests/compile_fail/non_enum.rs:4:1 3 | | 4 | 4 | struct Foo; 5 | | ^^^^^^^^^^^ 6 | -------------------------------------------------------------------------------- /tests/compile_fail/wrong_attr_types.rs: -------------------------------------------------------------------------------- 1 | use irpc::rpc_requests; 2 | 3 | #[rpc_requests(Service, Msg)] 4 | enum Enum { 5 | #[rpc(fnord = Bla)] 6 | A(u8), 7 | } 8 | 9 | fn main() {} -------------------------------------------------------------------------------- /tests/compile_fail/wrong_attr_types.stderr: -------------------------------------------------------------------------------- 1 | error: rpc requires a tx type 2 | --> tests/compile_fail/wrong_attr_types.rs:5:5 3 | | 4 | 5 | #[rpc(fnord = Bla)] 5 | | ^^^^^^^^^^^^^^^^^^^ 6 | -------------------------------------------------------------------------------- /tests/compile_fail/multiple_fields.stderr: -------------------------------------------------------------------------------- 1 | error: Each variant must have exactly one unnamed field 2 | --> tests/compile_fail/multiple_fields.rs:5:5 3 | | 4 | 5 | A(u8, u8), 5 | | ^^^^^^^^^ 6 | 
-------------------------------------------------------------------------------- /tests/compile_fail/named_enum.stderr: -------------------------------------------------------------------------------- 1 | error: Each variant must have exactly one unnamed field 2 | --> tests/compile_fail/named_enum.rs:5:5 3 | | 4 | 5 | A { name: u8 }, 5 | | ^^^^^^^^^^^^^^ 6 | -------------------------------------------------------------------------------- /tests/compile_fail/extra_attr_types.rs: -------------------------------------------------------------------------------- 1 | use irpc::rpc_requests; 2 | 3 | #[rpc_requests(Service, Msg)] 4 | enum Enum { 5 | #[rpc(tx = NoSender, rx = NoReceiver, fnord = Foo)] 6 | A(u8), 7 | } 8 | 9 | fn main() {} -------------------------------------------------------------------------------- /tests/compile_fail/duplicate_type.stderr: -------------------------------------------------------------------------------- 1 | error: Each variant must have a unique request type 2 | --> tests/compile_fail/duplicate_type.rs:4:1 3 | | 4 | 4 | / enum Enum { 5 | 5 | | A(u8), 6 | 6 | | B(u8), 7 | 7 | | } 8 | | |_^ 9 | -------------------------------------------------------------------------------- /tests/compile_fail/extra_attr_types.stderr: -------------------------------------------------------------------------------- 1 | error: Unknown arguments provided: ["fnord"] 2 | --> tests/compile_fail/extra_attr_types.rs:5:5 3 | | 4 | 5 | #[rpc(tx = NoSender, rx = NoReceiver, fnord = Foo)] 5 | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 6 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at 2 | 3 | http://www.apache.org/licenses/LICENSE-2.0 4 | 5 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -------------------------------------------------------------------------------- /irpc-derive/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "irpc-derive" 3 | version = "0.9.0" 4 | edition = "2021" 5 | authors = ["Rüdiger Klaehn "] 6 | keywords = ["api", "protocol", "network", "rpc", "macro"] 7 | categories = ["network-programming"] 8 | license = "Apache-2.0/MIT" 9 | repository = "https://github.com/n0-computer/irpc" 10 | description = "Macros for irpc" 11 | 12 | [lib] 13 | proc-macro = true 14 | 15 | [dependencies] 16 | syn = { version = "2", features = ["full"] } 17 | quote = "1" 18 | proc-macro2 = "1" 19 | -------------------------------------------------------------------------------- /src/tests.rs: -------------------------------------------------------------------------------- 1 | use std::vec; 2 | 3 | #[tokio::test] 4 | async fn test_map_filter() { 5 | use crate::channel::mpsc; 6 | let (tx, rx) = mpsc::channel::(100); 7 | // *2, filter multiples of 4, *3 if multiple of 8 8 | // 9 | // the transforms are applied in reverse order of the builder calls below (with_map runs first, with_filter_map last)!
10 | let tx = tx 11 | .with_filter_map(|x: u64| if x % 8 == 0 { Some(x * 3) } else { None }) 12 | .with_filter(|x| x % 4 == 0) 13 | .with_map(|x: u64| x * 2); 14 | for i in 0..100 { 15 | tx.send(i).await.ok(); 16 | } 17 | drop(tx); 18 | // /24, filter multiples of 3, /2 if even 19 | let mut rx = rx 20 | .map(|x: u64| x / 24) 21 | .filter(|x| x % 3 == 0) 22 | .filter_map(|x: u64| if x % 2 == 0 { Some(x / 2) } else { None }); 23 | let mut res = vec![]; 24 | while let Ok(Some(x)) = rx.recv().await { 25 | res.push(x); 26 | } 27 | assert_eq!(res, vec![0, 3, 6, 9, 12]); 28 | } 29 | -------------------------------------------------------------------------------- /DOCS.md: -------------------------------------------------------------------------------- 1 | Building docs for this crate is a bit complex. There are lots of feature flags, 2 | so we want feature flag markers in the docs, especially for the transports. 3 | 4 | There is an experimental cargo doc feature that adds feature flag markers. To 5 | get those, run docs with this command line: 6 | 7 | ```sh 8 | RUSTDOCFLAGS="--cfg quicrpc_docsrs" cargo +nightly doc --all-features --no-deps --open 9 | ``` 10 | 11 | This sets the flag `quicrpc_docsrs` when creating docs, which triggers statements 12 | like below that add feature flag markers. Note that you *need* nightly for this feature 13 | as of now.
14 | 15 | ```rust 16 | #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "flume-transport")))] 17 | ``` 18 | 19 | The feature is *enabled* using this statement in lib.rs: 20 | 21 | ```rust 22 | #![cfg_attr(quicrpc_docsrs, feature(doc_cfg))] 23 | ``` 24 | 25 | We tell [docs.rs](https://docs.rs) to use the `quicrpc_docsrs` config using these statements 26 | in Cargo.toml: 27 | 28 | ```toml 29 | [package.metadata.docs.rs] 30 | all-features = true 31 | rustdoc-args = ["--cfg", "quicrpc_docsrs"] 32 | ``` -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE.
-------------------------------------------------------------------------------- /irpc-iroh/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "irpc-iroh" 3 | version = "0.11.0" 4 | edition = "2021" 5 | authors = ["Rüdiger Klaehn ", "n0 team"] 6 | keywords = ["api", "protocol", "network", "rpc"] 7 | categories = ["network-programming"] 8 | license = "Apache-2.0/MIT" 9 | repository = "https://github.com/n0-computer/irpc" 10 | description = "Iroh transport for irpc" 11 | 12 | [lib] 13 | crate-type = ["cdylib", "rlib"] 14 | 15 | [dependencies] 16 | iroh = { workspace = true } 17 | tokio = { workspace = true, features = ["sync"] } 18 | tracing = { workspace = true } 19 | serde = { workspace = true } 20 | postcard = { workspace = true, features = ["alloc", "use-std"] } 21 | n0-error = { workspace = true } 22 | n0-future = { workspace = true } 23 | irpc = { version = "0.11.0", path = ".." } 24 | iroh-base.workspace = true 25 | 26 | [target.'cfg(all(target_family = "wasm", target_os = "unknown"))'.dependencies] 27 | getrandom = { version = "0.3", features = ["wasm_js"] } 28 | 29 | [dev-dependencies] 30 | n0-future = { workspace = true } 31 | tracing-subscriber = { workspace = true, features = ["fmt"] } 32 | irpc-derive = { version = "0.9.0", path = "../irpc-derive" } 33 | clap = { version = "4.5.41", features = ["derive"] } 34 | futures-util.workspace = true 35 | hex = "0.4.3" 36 | rand = "0.9.2" 37 | anyhow = { workspace = true } 38 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IRPC 2 | 3 | A streaming rpc system for iroh 4 | 5 | [][repo link] [![Latest Version]][crates.io] [![Docs Badge]][docs.rs] ![license badge] [![status badge]][status link] 6 | 7 | [Latest Version]: https://img.shields.io/crates/v/irpc.svg 8 | [crates.io]: https://crates.io/crates/irpc 9 | [Docs 
Badge]: https://img.shields.io/badge/docs-docs.rs-green 10 | [docs.rs]: https://docs.rs/irpc 11 | [license badge]: https://img.shields.io/crates/l/irpc 12 | [status badge]: https://github.com/n0-computer/irpc/actions/workflows/ci.yml/badge.svg 13 | [status link]: https://github.com/n0-computer/irpc/actions/workflows/ci.yml 14 | [repo link]: https://github.com/n0-computer/irpc 15 | 16 | # Goals 17 | 18 | See the [module docs](https://docs.rs/irpc/latest/irpc/). 19 | 20 | # Docs 21 | 22 | Properly building docs for this crate is quite complex. For all the gory details, 23 | see [DOCS.md](DOCS.md). 24 | 25 | ## License 26 | 27 | Copyright 2025 N0, INC. 28 | 29 | This project is licensed under either of 30 | 31 | * Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or 32 | http://www.apache.org/licenses/LICENSE-2.0) 33 | * MIT license ([LICENSE-MIT](LICENSE-MIT) or 34 | http://opensource.org/licenses/MIT) 35 | 36 | at your option. 37 | 38 | ## Contribution 39 | 40 | Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in this project by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.
41 | -------------------------------------------------------------------------------- /tests/common.rs: -------------------------------------------------------------------------------- 1 | #![cfg(feature = "quinn_endpoint_setup")] 2 | 3 | use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; 4 | 5 | use irpc::util::{make_client_endpoint, make_server_endpoint}; 6 | use n0_error::stack_error; 7 | use quinn::Endpoint; 8 | use serde::{Deserialize, Deserializer, Serialize, Serializer}; 9 | use testresult::TestResult; 10 | 11 | pub fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAddr)> { 12 | let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(); 13 | let (server, cert) = make_server_endpoint(addr)?; 14 | let client = make_client_endpoint(addr, &[cert.as_slice()])?; 15 | let port = server.local_addr()?.port(); 16 | let server_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into(); 17 | Ok((server, client, server_addr)) 18 | } 19 | 20 | #[derive(Debug)] 21 | pub struct NoSer(pub u64); 22 | 23 | #[stack_error(derive)] 24 | #[error("Cannot serialize odd number")] 25 | pub struct OddNumberError(u64); 26 | 27 | impl Serialize for NoSer { 28 | fn serialize(&self, serializer: S) -> Result 29 | where 30 | S: Serializer, 31 | { 32 | if self.0 % 2 == 1 { 33 | Err(serde::ser::Error::custom(OddNumberError(self.0))) 34 | } else { 35 | serializer.serialize_u64(self.0) 36 | } 37 | } 38 | } 39 | 40 | impl<'de> Deserialize<'de> for NoSer { 41 | fn deserialize(deserializer: D) -> Result 42 | where 43 | D: Deserializer<'de>, 44 | { 45 | let value = u64::deserialize(deserializer)?; 46 | if value % 2 != 0 { 47 | Err(serde::de::Error::custom(OddNumberError(value))) 48 | } else { 49 | Ok(NoSer(value)) 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /tests/derive.rs: -------------------------------------------------------------------------------- 1 | #![cfg(feature = "derive")] 2 | 3 | use irpc::{ 4 | 
channel::{none::NoSender, oneshot}, 5 | rpc_requests, 6 | }; 7 | use serde::{Deserialize, Serialize}; 8 | 9 | #[test] 10 | fn derive_simple() { 11 | #![allow(dead_code)] 12 | 13 | #[derive(Debug, Serialize, Deserialize)] 14 | struct RpcRequest; 15 | 16 | #[derive(Debug, Serialize, Deserialize)] 17 | struct ServerStreamingRequest; 18 | 19 | #[derive(Debug, Serialize, Deserialize)] 20 | struct ClientStreamingRequest; 21 | 22 | #[derive(Debug, Serialize, Deserialize)] 23 | struct BidiStreamingRequest; 24 | 25 | #[derive(Debug, Serialize, Deserialize)] 26 | struct Update1; 27 | 28 | #[derive(Debug, Serialize, Deserialize)] 29 | struct Update2; 30 | 31 | #[derive(Debug, Serialize, Deserialize)] 32 | struct Response1; 33 | 34 | #[derive(Debug, Serialize, Deserialize)] 35 | struct Response2; 36 | 37 | #[derive(Debug, Serialize, Deserialize)] 38 | struct Response3; 39 | 40 | #[derive(Debug, Serialize, Deserialize)] 41 | struct Response4; 42 | 43 | #[rpc_requests(message = RequestWithChannels, no_rpc, no_spans)] 44 | #[derive(Debug, Serialize, Deserialize)] 45 | enum Request { 46 | #[rpc(tx=oneshot::Sender<()>)] 47 | Rpc(RpcRequest), 48 | #[rpc(tx=NoSender)] 49 | ServerStreaming(ServerStreamingRequest), 50 | #[rpc(tx=NoSender)] 51 | BidiStreaming(BidiStreamingRequest), 52 | #[rpc(tx=NoSender)] 52 | ClientStreaming(ClientStreamingRequest), 54 | } 55 | } 56 | 57 | /// Use 58 | /// 59 | /// TRYBUILD=overwrite cargo test --test derive -- --ignored 60 | /// 61 | /// to update the `.stderr` snapshots in tests/compile_fail (`--ignored` is required because this test is ignored by default) 62 | #[test] 63 | #[ignore = "stupid diffs depending on rustc version"] 64 | fn compile_fail() { 65 | let t = trybuild::TestCases::new(); 66 | t.compile_fail("tests/compile_fail/*.rs"); 67 | } 68 | -------------------------------------------------------------------------------- /examples/local.rs: 
-------------------------------------------------------------------------------- 1 | //! This demonstrates using irpc with the derive macro but without the rpc feature 2 | //! for local-only use. Run with: 3 | //!
``` 4 | //! cargo run --example local --no-default-features --features derive 5 | //! ``` 6 | 7 | use std::collections::BTreeMap; 8 | 9 | use irpc::{channel::oneshot, rpc_requests, Client, WithChannels}; 10 | use serde::{Deserialize, Serialize}; 11 | 12 | #[derive(Debug, Serialize, Deserialize)] 13 | struct Get { 14 | key: String, 15 | } 16 | 17 | #[derive(Debug, Serialize, Deserialize)] 18 | struct Set { 19 | key: String, 20 | value: String, 21 | } 22 | 23 | impl From<(String, String)> for Set { 24 | fn from((key, value): (String, String)) -> Self { 25 | Self { key, value } 26 | } 27 | } 28 | 29 | #[rpc_requests(message = StorageMessage, no_rpc, no_spans)] 30 | #[derive(Serialize, Deserialize, Debug)] 31 | enum StorageProtocol { 32 | #[rpc(tx=oneshot::Sender>)] 33 | Get(Get), 34 | #[rpc(tx=oneshot::Sender<()>)] 35 | Set(Set), 36 | } 37 | 38 | struct Actor { 39 | recv: tokio::sync::mpsc::Receiver, 40 | state: BTreeMap, 41 | } 42 | 43 | impl Actor { 44 | async fn run(mut self) { 45 | while let Some(msg) = self.recv.recv().await { 46 | self.handle(msg).await; 47 | } 48 | } 49 | 50 | async fn handle(&mut self, msg: StorageMessage) { 51 | match msg { 52 | StorageMessage::Get(get) => { 53 | let WithChannels { tx, inner, .. } = get; 54 | tx.send(self.state.get(&inner.key).cloned()).await.ok(); 55 | } 56 | StorageMessage::Set(set) => { 57 | let WithChannels { tx, inner, .. 
} = set; 58 | self.state.insert(inner.key, inner.value); 59 | tx.send(()).await.ok(); 60 | } 61 | } 62 | } 63 | } 64 | 65 | struct StorageApi { 66 | inner: Client, 67 | } 68 | 69 | impl StorageApi { 70 | pub fn spawn() -> StorageApi { 71 | let (tx, rx) = tokio::sync::mpsc::channel(1); 72 | let actor = Actor { 73 | recv: rx, 74 | state: BTreeMap::new(), 75 | }; 76 | n0_future::task::spawn(actor.run()); 77 | StorageApi { 78 | inner: Client::local(tx), 79 | } 80 | } 81 | 82 | pub async fn get(&self, key: String) -> irpc::Result> { 83 | self.inner.rpc(Get { key }).await 84 | } 85 | 86 | pub async fn set(&self, key: String, value: String) -> irpc::Result<()> { 87 | self.inner.rpc(Set { key, value }).await 88 | } 89 | } 90 | 91 | #[tokio::main] 92 | async fn main() -> irpc::Result<()> { 93 | tracing_subscriber::fmt::init(); 94 | let api = StorageApi::spawn(); 95 | api.set("hello".to_string(), "world".to_string()).await?; 96 | let value = api.get("hello".to_string()).await?; 97 | println!("get: hello = {value:?}"); 98 | Ok(()) 99 | } 100 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "irpc" 3 | version = "0.11.0" 4 | edition = "2021" 5 | authors = ["Rüdiger Klaehn ", "n0 team"] 6 | keywords = ["api", "protocol", "network", "rpc"] 7 | categories = ["network-programming"] 8 | license = "Apache-2.0/MIT" 9 | repository = "https://github.com/n0-computer/irpc" 10 | description = "A streaming rpc system based on quic" 11 | 12 | # Sadly this also needs to be updated in .github/workflows/ci.yml 13 | rust-version = "1.76" 14 | 15 | [lib] 16 | crate-type = ["cdylib", "rlib"] 17 | 18 | [dependencies] 19 | # we require serde even in non-rpc mode 20 | serde = { workspace = true } 21 | # just for the oneshot and mpsc queues, and tokio::select! 
22 | tokio = { workspace = true, features = ["sync", "macros"] } 23 | # for PollSender (which for some reason is not available in the main tokio api) 24 | tokio-util = { version = "0.7.14", default-features = false } 25 | # errors 26 | n0-error = { workspace = true } 27 | 28 | # used in the endpoint handler code when using rpc 29 | tracing = { workspace = true, optional = true } 30 | # used to ser/de messages when using rpc 31 | postcard = { workspace = true, features = ["alloc", "use-std"], optional = true } 32 | # currently only transport when using rpc 33 | quinn = { workspace = true, optional = true } 34 | # used as a buffer for serialization when using rpc 35 | smallvec = { version = "1.14.0", features = ["write"], optional = true } 36 | # used in the test utils to generate quinn endpoints 37 | rustls = { version = "0.23.5", default-features = false, features = ["std"], optional = true } 38 | # used in the test utils to generate quinn endpoints 39 | rcgen = { version = "0.14.5", optional = true } 40 | # used in the benches 41 | futures-buffered ={ version = "0.2.9", optional = true } 42 | # for AbortOnDropHandle 43 | n0-future = { workspace = true } 44 | futures-util = { workspace = true, optional = true } 45 | # for the derive reexport/feature 46 | irpc-derive = { version = "0.9.0", path = "./irpc-derive", optional = true } 47 | 48 | [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] 49 | quinn = { workspace = true, optional = true, features = ["runtime-tokio"] } 50 | 51 | [dev-dependencies] 52 | tracing-subscriber = { workspace = true, features = ["fmt"] } 53 | # just convenient for the enum definitions, in the manual example 54 | derive_more = { version = "2", features = ["from"] } 55 | # we need full for example main etc. 
56 | tokio = { workspace = true, features = ["full"] } 57 | # formatting 58 | thousands = "0.2.0" 59 | # macro tests 60 | trybuild = "1.0.104" 61 | testresult = "0.4.1" 62 | # used in examples 63 | anyhow = { workspace = true } 64 | 65 | [features] 66 | # enable the remote transport 67 | rpc = ["dep:quinn", "dep:postcard", "dep:smallvec", "dep:tracing", "tokio/io-util"] 68 | # add test utilities 69 | quinn_endpoint_setup = ["rpc", "dep:rustls", "dep:rcgen", "dep:futures-buffered", "quinn/rustls-ring"] 70 | # pick up parent span when creating channel messages 71 | spans = ["dep:tracing"] 72 | stream = ["dep:futures-util"] 73 | derive = ["dep:irpc-derive"] 74 | varint-util = ["dep:postcard", "dep:smallvec", "tokio/io-util"] 75 | default = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"] 76 | 77 | [[example]] 78 | name = "derive" 79 | required-features = ["rpc", "derive", "quinn_endpoint_setup"] 80 | 81 | [[example]] 82 | name = "compute" 83 | required-features = ["rpc", "derive", "quinn_endpoint_setup"] 84 | 85 | [[example]] 86 | name = "local" 87 | required-features = ["derive"] 88 | 89 | [[example]] 90 | name = "storage" 91 | required-features = ["rpc", "quinn_endpoint_setup"] 92 | 93 | [workspace] 94 | members = ["irpc-derive", "irpc-iroh"] 95 | 96 | [package.metadata.docs.rs] 97 | all-features = true 98 | rustdoc-args = ["--cfg", "quicrpc_docsrs"] 99 | 100 | [lints.rust] 101 | unexpected_cfgs = { level = "warn", check-cfg = ["cfg(quicrpc_docsrs)"] } 102 | 103 | [workspace.dependencies] 104 | anyhow = { version = "1" } 105 | tokio = { version = "1.44", default-features = false } 106 | postcard = { version = "1.1.1", default-features = false } 107 | serde = { version = "1", default-features = false, features = ["derive"] } 108 | tracing = { version = "0.1.41", default-features = false } 109 | n0-future = { version = "0.3", default-features = false } 110 | n0-error = { version = "0.1.0" } 111 | tracing-subscriber = { version = "0.3.20" } 112 | iroh = { 
version = "0.95" } 113 | iroh-base = { version = "0.95" } 114 | quinn = { package = "iroh-quinn", version = "0.14.0", default-features = false } 115 | futures-util = { version = "0.3", features = ["sink"] } 116 | -------------------------------------------------------------------------------- /irpc-iroh/examples/simple.rs: -------------------------------------------------------------------------------- 1 | #[tokio::main] 2 | async fn main() -> anyhow::Result<()> { 3 | cli::run().await 4 | } 5 | 6 | mod proto { 7 | use std::collections::HashMap; 8 | 9 | use anyhow::Result; 10 | use iroh::{protocol::Router, Endpoint, EndpointId}; 11 | use irpc::{channel::oneshot, rpc_requests, Client, WithChannels}; 12 | use irpc_iroh::IrohProtocol; 13 | use serde::{Deserialize, Serialize}; 14 | 15 | const ALPN: &[u8] = b"iroh-irpc/simple/1"; 16 | 17 | #[rpc_requests(message = FooMessage)] 18 | #[derive(Debug, Serialize, Deserialize)] 19 | pub enum FooProtocol { 20 | /// This is the get request. 21 | #[rpc(tx=oneshot::Sender>)] 22 | #[wrap(GetRequest, derive(Clone))] 23 | Get(String), 24 | 25 | /// This is the set request. 
26 | #[rpc(tx=oneshot::Sender>)] 27 | #[wrap(SetRequest)] 28 | Set { 29 | /// This is the key 30 | key: String, 31 | /// This is the value 32 | value: String, 33 | }, 34 | } 35 | 36 | pub async fn listen() -> Result<()> { 37 | let (tx, rx) = tokio::sync::mpsc::channel(16); 38 | tokio::task::spawn(actor(rx)); 39 | let client = Client::::local(tx); 40 | 41 | let endpoint = Endpoint::bind().await?; 42 | let protocol = IrohProtocol::with_sender(client.as_local().unwrap()); 43 | let router = Router::builder(endpoint).accept(ALPN, protocol).spawn(); 44 | println!("endpoint id: {}", router.endpoint().id()); 45 | 46 | tokio::signal::ctrl_c().await?; 47 | router.shutdown().await?; 48 | Ok(()) 49 | } 50 | 51 | async fn actor(mut rx: tokio::sync::mpsc::Receiver) { 52 | let mut store = HashMap::new(); 53 | while let Some(msg) = rx.recv().await { 54 | match msg { 55 | FooMessage::Get(msg) => { 56 | let WithChannels { inner, tx, .. } = msg; 57 | println!("handle request: {inner:?}"); 58 | 59 | // We can clone `inner` because we added the `Clone` derive to the `wrap` attribute: 60 | let _ = inner.clone(); 61 | 62 | let GetRequest(key) = inner; 63 | let value = store.get(&key).cloned(); 64 | tx.send(value).await.ok(); 65 | } 66 | FooMessage::Set(msg) => { 67 | let WithChannels { inner, tx, .. 
} = msg; 68 | println!("handle request: {inner:?}"); 69 | let SetRequest { key, value } = inner; 70 | let prev_value = store.insert(key, value); 71 | tx.send(prev_value).await.ok(); 72 | } 73 | } 74 | } 75 | } 76 | 77 | pub async fn connect(endpoint_id: EndpointId) -> Result> { 78 | println!("connecting to {endpoint_id}"); 79 | let endpoint = Endpoint::bind().await?; 80 | let client = irpc_iroh::client(endpoint, endpoint_id, ALPN); 81 | Ok(client) 82 | } 83 | } 84 | 85 | mod cli { 86 | use anyhow::Result; 87 | use clap::Parser; 88 | use iroh::EndpointId; 89 | 90 | use crate::proto::{connect, listen, GetRequest, SetRequest}; 91 | 92 | #[derive(Debug, Parser)] 93 | enum Cli { 94 | Listen, 95 | Connect { 96 | endpoint_id: EndpointId, 97 | #[clap(subcommand)] 98 | command: Command, 99 | }, 100 | } 101 | 102 | #[derive(Debug, Parser)] 103 | enum Command { 104 | Get { key: String }, 105 | Set { key: String, value: String }, 106 | } 107 | 108 | pub async fn run() -> Result<()> { 109 | match Cli::parse() { 110 | Cli::Listen => listen().await?, 111 | Cli::Connect { 112 | endpoint_id, 113 | command, 114 | } => { 115 | let client = connect(endpoint_id).await?; 116 | match command { 117 | Command::Get { key } => { 118 | println!("get '{key}'"); 119 | let value = client.rpc(GetRequest(key)).await?; 120 | println!("{value:?}"); 121 | } 122 | Command::Set { key, value } => { 123 | println!("set '{key}' to '{value}'"); 124 | let value = client.rpc(SetRequest { key, value }).await?; 125 | println!("OK (previous: {value:?})"); 126 | } 127 | } 128 | } 129 | } 130 | Ok(()) 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /tests/oneshot_channel.rs: -------------------------------------------------------------------------------- 1 | #![cfg(feature = "quinn_endpoint_setup")] 2 | 3 | use std::io::{self, ErrorKind}; 4 | 5 | use irpc::{ 6 | channel::{ 7 | oneshot::{self, RecvError}, 8 | SendError, 9 | }, 10 | util::AsyncWriteVarintExt, 
11 | }; 12 | use n0_error::e; 13 | use quinn::Endpoint; 14 | use testresult::TestResult; 15 | 16 | mod common; 17 | use common::*; 18 | 19 | async fn vec_receiver(server: Endpoint) -> Result<(), RecvError> { 20 | let conn = server 21 | .accept() 22 | .await 23 | .unwrap() 24 | .await 25 | .map_err(|err| e!(RecvError::Io, err.into()))?; 26 | let (_, recv) = conn 27 | .accept_bi() 28 | .await 29 | .map_err(|err| e!(RecvError::Io, err.into()))?; 30 | let recv = oneshot::Receiver::>::from(recv); 31 | recv.await?; 32 | Err(e!(RecvError::Io, io::ErrorKind::UnexpectedEof.into())) 33 | } 34 | 35 | /// Checks that the max message size is enforced on the sender side and that errors are propagated to the receiver side. 36 | #[tokio::test] 37 | async fn oneshot_max_message_size_send() -> TestResult<()> { 38 | let (server, client, server_addr) = create_connected_endpoints()?; 39 | let server = tokio::spawn(vec_receiver(server)); 40 | let conn = client.connect(server_addr, "localhost")?.await?; 41 | let (send, _) = conn.open_bi().await?; 42 | let send = oneshot::Sender::>::from(send); 43 | // this one should fail! 44 | let Err(cause) = send.send(vec![0u8; 1024 * 1024 * 32]).await else { 45 | panic!("client should have failed due to max message size"); 46 | }; 47 | assert!(matches!(cause, SendError::MaxMessageSizeExceeded { .. })); 48 | let Err(cause) = server.await? else { 49 | panic!("server should have failed due to max message size"); 50 | }; 51 | assert!( 52 | matches!(cause, RecvError::Io { source, .. } if source.kind() == ErrorKind::ConnectionReset) 53 | ); 54 | Ok(()) 55 | } 56 | 57 | /// Checks that the max message size is enforced on receiver side. 
58 | #[tokio::test] 59 | async fn oneshot_max_message_size_recv() -> TestResult<()> { 60 | let (server, client, server_addr) = create_connected_endpoints()?; 61 | let server = tokio::spawn(vec_receiver(server)); 62 | let conn = client.connect(server_addr, "localhost")?.await?; 63 | let (mut send, _) = conn.open_bi().await?; 64 | // this one should fail on receive! 65 | send.write_length_prefixed(vec![0u8; 1024 * 1024 * 32]) 66 | .await 67 | .ok(); 68 | let Err(cause) = server.await? else { 69 | panic!("server should have failed due to max message size"); 70 | }; 71 | assert!(matches!(cause, RecvError::MaxMessageSizeExceeded { .. })); 72 | Ok(()) 73 | } 74 | 75 | async fn noser_receiver(server: Endpoint) -> Result<(), RecvError> { 76 | let conn = server 77 | .accept() 78 | .await 79 | .unwrap() 80 | .await 81 | .map_err(|err| e!(RecvError::Io, err.into()))?; 82 | let (_, recv) = conn 83 | .accept_bi() 84 | .await 85 | .map_err(|err| e!(RecvError::Io, err.into()))?; 86 | let recv = oneshot::Receiver::::from(recv); 87 | recv.await?; 88 | Err(e!(RecvError::Io, io::ErrorKind::UnexpectedEof.into())) 89 | } 90 | 91 | /// Checks that trying to send a message that cannot be serialized results in an error on the sender side and a connection reset on the receiver side. 92 | #[tokio::test] 93 | async fn oneshot_serialize_error_send() -> TestResult<()> { 94 | let (server, client, server_addr) = create_connected_endpoints()?; 95 | let server = tokio::spawn(noser_receiver(server)); 96 | let conn = client.connect(server_addr, "localhost")?.await?; 97 | let (send, _) = conn.open_bi().await?; 98 | let send = oneshot::Sender::::from(send); 99 | // this one should fail! 100 | let Err(cause) = send.send(NoSer(1)).await else { 101 | panic!("client should have failed due to serialization error"); 102 | }; 103 | assert!( 104 | matches!(cause, SendError::Io { source, .. } if source.kind() == ErrorKind::InvalidData) 105 | ); 106 | let Err(cause) = server.await? 
else { 107 | panic!("server should have failed due to serialization error"); 108 | }; 109 | println!("Server error: {cause:?}"); 110 | assert!( 111 | matches!(cause, RecvError::Io { source, .. } if source.kind() == ErrorKind::ConnectionReset) 112 | ); 113 | Ok(()) 114 | } 115 | 116 | #[tokio::test] 117 | async fn oneshot_serialize_error_recv() -> TestResult<()> { 118 | let (server, client, server_addr) = create_connected_endpoints()?; 119 | let server = tokio::spawn(noser_receiver(server)); 120 | let conn = client.connect(server_addr, "localhost")?.await?; 121 | let (mut send, _) = conn.open_bi().await?; 122 | // this one should fail on receive! 123 | send.write_length_prefixed(1u64).await?; 124 | send.finish()?; 125 | let Err(cause) = server.await? else { 126 | panic!("server should have failed due to serialization error"); 127 | }; 128 | println!("Server error: {cause:?}"); 129 | assert!( 130 | matches!(cause, RecvError::Io { source, .. } if source.kind() == ErrorKind::InvalidData) 131 | ); 132 | Ok(()) 133 | } 134 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: 6 | # This helps fill the caches properly, caches are not shared between PRs. 
7 | - main 8 | pull_request: 9 | branches: 10 | - main 11 | 12 | concurrency: 13 | group: tests-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 14 | cancel-in-progress: true 15 | 16 | env: 17 | MSRV: "1.76" 18 | RUST_BACKTRACE: 1 19 | RUSTFLAGS: -Dwarnings 20 | 21 | jobs: 22 | lint: 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@v2 26 | - uses: dtolnay/rust-toolchain@stable 27 | with: 28 | components: rustfmt, clippy 29 | - uses: swatinem/rust-cache@v2 30 | - name: cargo fmt 31 | run: cargo fmt --all -- --check 32 | - name: cargo clippy 33 | run: cargo clippy --locked --workspace --all-targets --all-features 34 | 35 | test: 36 | runs-on: ${{ matrix.target.os }} 37 | strategy: 38 | fail-fast: false 39 | matrix: 40 | target: 41 | - os: "ubuntu-latest" 42 | toolchain: "x86_64-unknown-linux-gnu" 43 | name: "Linux GNU" 44 | - os: "macOS-latest" 45 | toolchain: "x86_64-apple-darwin" 46 | name: "macOS" 47 | - os: "windows-latest" 48 | toolchain: "x86_64-pc-windows-msvc" 49 | name: "Windows MSVC" 50 | - os: "windows-latest" 51 | toolchain: "x86_64-pc-windows-gnu" 52 | name: "Windows GNU" 53 | channel: 54 | - "stable" 55 | steps: 56 | - uses: actions/checkout@v2 57 | - uses: dtolnay/rust-toolchain@master 58 | with: 59 | toolchain: ${{ matrix.channel }} 60 | targets: ${{ matrix.target.toolchain }} 61 | - uses: swatinem/rust-cache@v2 62 | - name: cargo test (workspace, all features) 63 | run: cargo test --locked --workspace --all-features --bins --tests --examples 64 | - name: cargo test (workspace, default features) 65 | run: cargo test --locked --workspace --bins --tests --examples 66 | - name: cargo test (workspace, no default features) 67 | run: cargo test --locked --workspace --no-default-features --bins --tests --examples 68 | - name: cargo check (irpc, no default features) 69 | run: cargo check --locked --no-default-features --bins --tests --examples 70 | - name: cargo check (irpc, feature derive) 71 | run: cargo 
check --locked --no-default-features --features derive --bins --tests --examples 72 | - name: cargo check (irpc, feature spans) 73 | run: cargo check --locked --no-default-features --features spans --bins --tests --examples 74 | - name: cargo check (irpc, feature rpc) 75 | run: cargo check --locked --no-default-features --features rpc --bins --tests --examples 76 | 77 | test-release: 78 | runs-on: ${{ matrix.target.os }} 79 | strategy: 80 | fail-fast: false 81 | matrix: 82 | target: 83 | - os: "ubuntu-latest" 84 | toolchain: "x86_64-unknown-linux-gnu" 85 | name: "Linux GNU" 86 | - os: "macOS-latest" 87 | toolchain: "x86_64-apple-darwin" 88 | name: "macOS" 89 | - os: "windows-latest" 90 | toolchain: "x86_64-pc-windows-msvc" 91 | name: "Windows MSVC" 92 | - os: "windows-latest" 93 | toolchain: "x86_64-pc-windows-gnu" 94 | name: "Windows GNU" 95 | channel: 96 | - "stable" 97 | - "beta" 98 | steps: 99 | - uses: actions/checkout@v2 100 | - uses: dtolnay/rust-toolchain@master 101 | with: 102 | toolchain: ${{ matrix.channel }} 103 | targets: ${{ matrix.target.toolchain }} 104 | - uses: swatinem/rust-cache@v2 105 | - name: cargo test 106 | run: cargo test --release --locked --workspace --all-features --bins --tests --examples 107 | 108 | wasm_build: 109 | name: Build wasm32 110 | runs-on: ubuntu-latest 111 | env: 112 | RUSTFLAGS: '--cfg getrandom_backend="wasm_js"' 113 | steps: 114 | - name: Checkout sources 115 | uses: actions/checkout@v4 116 | - name: Install stable toolchain 117 | uses: dtolnay/rust-toolchain@stable 118 | - name: Add wasm target 119 | run: rustup target add wasm32-unknown-unknown 120 | - name: Install wasm-tools 121 | uses: bytecodealliance/actions/wasm-tools/setup@v1 122 | - name: wasm32 build 123 | run: cargo build --target wasm32-unknown-unknown --all 124 | # If the Wasm file contains any 'import "env"' declarations, then 125 | # some non-Wasm-compatible code made it into the final code. 
126 | - name: Ensure no 'import "env"' in wasm 127 | run: | 128 | ! wasm-tools print --skeleton target/wasm32-unknown-unknown/debug/irpc.wasm | grep 'import "env"' 129 | ! wasm-tools print --skeleton target/wasm32-unknown-unknown/debug/irpc_iroh.wasm | grep 'import "env"' 130 | 131 | # Checks correct runtime deps and features are requested by not including dev-dependencies. 132 | check-deps: 133 | runs-on: ubuntu-latest 134 | steps: 135 | - uses: actions/checkout@v2 136 | - uses: dtolnay/rust-toolchain@stable 137 | - uses: swatinem/rust-cache@v2 138 | - name: cargo check 139 | run: cargo check --workspace --all-features --lib --bins 140 | 141 | minimal-crates: 142 | runs-on: ubuntu-latest 143 | steps: 144 | - uses: actions/checkout@v2 145 | - uses: dtolnay/rust-toolchain@nightly 146 | - uses: swatinem/rust-cache@v2 147 | - name: cargo check 148 | run: | 149 | rm -f Cargo.lock 150 | cargo +nightly check -Z minimal-versions -p irpc -p irpc-derive --all-features --lib --bins 151 | -------------------------------------------------------------------------------- /irpc-iroh/examples/derive.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use iroh::{protocol::Router, Endpoint}; 3 | 4 | use self::storage::StorageApi; 5 | 6 | #[tokio::main] 7 | async fn main() -> Result<()> { 8 | tracing_subscriber::fmt::init(); 9 | println!("Local use"); 10 | local().await?; 11 | println!("Remote use"); 12 | remote().await?; 13 | Ok(()) 14 | } 15 | 16 | async fn local() -> Result<()> { 17 | let api = StorageApi::spawn(); 18 | api.set("hello".to_string(), "world".to_string()).await?; 19 | let value = api.get("hello".to_string()).await?; 20 | let mut list = api.list().await?; 21 | while let Some(value) = list.recv().await? 
{ 22 | println!("list value = {value:?}"); 23 | } 24 | println!("value = {value:?}"); 25 | Ok(()) 26 | } 27 | 28 | async fn remote() -> Result<()> { 29 | let (server_router, server_addr) = { 30 | let endpoint = Endpoint::bind().await?; 31 | let api = StorageApi::spawn(); 32 | let router = Router::builder(endpoint.clone()) 33 | .accept(StorageApi::ALPN, api.expose()?) 34 | .spawn(); 35 | let addr = endpoint.addr(); 36 | (router, addr) 37 | }; 38 | 39 | let client_endpoint = Endpoint::builder().bind().await?; 40 | let api = StorageApi::connect(client_endpoint, server_addr)?; 41 | api.set("hello".to_string(), "world".to_string()).await?; 42 | api.set("goodbye".to_string(), "world".to_string()).await?; 43 | let value = api.get("hello".to_string()).await?; 44 | println!("value = {value:?}"); 45 | let mut list = api.list().await?; 46 | while let Some(value) = list.recv().await? { 47 | println!("list value = {value:?}"); 48 | } 49 | drop(server_router); 50 | Ok(()) 51 | } 52 | 53 | mod storage { 54 | //! Implementation of our storage service. 55 | //! 56 | //! The only `pub` item is [`StorageApi`], everything else is private. 
57 | 58 | use std::collections::BTreeMap; 59 | 60 | use anyhow::{Context, Result}; 61 | use iroh::{protocol::ProtocolHandler, Endpoint}; 62 | use irpc::{ 63 | channel::{mpsc, oneshot}, 64 | rpc::RemoteService, 65 | rpc_requests, Client, WithChannels, 66 | }; 67 | // irpc-iroh integration types; the `rpc_requests` macro itself is imported from `irpc` above 68 | use irpc_iroh::{IrohLazyRemoteConnection, IrohProtocol}; 69 | use serde::{Deserialize, Serialize}; 70 | use tracing::info; 71 | 72 | #[derive(Debug, Serialize, Deserialize)] 73 | struct Get { 74 | key: String, 75 | } 76 | 77 | #[derive(Debug, Serialize, Deserialize)] 78 | struct List; 79 | 80 | #[derive(Debug, Serialize, Deserialize)] 81 | struct Set { 82 | key: String, 83 | value: String, 84 | } 85 | 86 | // Use the macro to generate both the StorageProtocol and StorageMessage enums 87 | // plus implement Channels for each type 88 | #[rpc_requests(message = StorageMessage)] 89 | #[derive(Serialize, Deserialize, Debug)] 90 | enum StorageProtocol { 91 | #[rpc(tx=oneshot::Sender>)] 92 | Get(Get), 93 | #[rpc(tx=oneshot::Sender<()>)] 94 | Set(Set), 95 | #[rpc(tx=mpsc::Sender)] 96 | List(List), 97 | } 98 | 99 | struct StorageActor { 100 | recv: tokio::sync::mpsc::Receiver, 101 | state: BTreeMap, 102 | } 103 | 104 | impl StorageActor { 105 | pub fn spawn() -> StorageApi { 106 | let (tx, rx) = tokio::sync::mpsc::channel(1); 107 | let actor = Self { 108 | recv: rx, 109 | state: BTreeMap::new(), 110 | }; 111 | n0_future::task::spawn(actor.run()); 112 | StorageApi { 113 | inner: Client::local(tx), 114 | } 115 | } 116 | 117 | async fn run(mut self) { 118 | while let Some(msg) = self.recv.recv().await { 119 | self.handle(msg).await; 120 | } 121 | } 122 | 123 | async fn handle(&mut self, msg: StorageMessage) { 124 | match msg { 125 | StorageMessage::Get(get) => { 126 | info!("get {:?}", get); 127 | let WithChannels { tx, inner, ..
} = get; 128 | tx.send(self.state.get(&inner.key).cloned()).await.ok(); 129 | } 130 | StorageMessage::Set(set) => { 131 | info!("set {:?}", set); 132 | let WithChannels { tx, inner, .. } = set; 133 | self.state.insert(inner.key, inner.value); 134 | tx.send(()).await.ok(); 135 | } 136 | StorageMessage::List(list) => { 137 | info!("list {:?}", list); 138 | let WithChannels { tx, .. } = list; 139 | for (key, value) in &self.state { 140 | if tx.send(format!("{key}={value}")).await.is_err() { 141 | break; 142 | } 143 | } 144 | } 145 | } 146 | } 147 | } 148 | 149 | pub struct StorageApi { 150 | inner: Client, 151 | } 152 | 153 | impl StorageApi { 154 | pub const ALPN: &[u8] = b"irpc-iroh/derive-demo/0"; 155 | 156 | pub fn spawn() -> Self { 157 | StorageActor::spawn() 158 | } 159 | 160 | pub fn connect( 161 | endpoint: Endpoint, 162 | addr: impl Into, 163 | ) -> Result { 164 | let conn = IrohLazyRemoteConnection::new(endpoint, addr.into(), Self::ALPN.to_vec()); 165 | Ok(StorageApi { 166 | inner: Client::boxed(conn), 167 | }) 168 | } 169 | 170 | pub fn expose(&self) -> Result { 171 | let local = self 172 | .inner 173 | .as_local() 174 | .context("can not listen on remote service")?; 175 | Ok(IrohProtocol::new(StorageProtocol::remote_handler(local))) 176 | } 177 | 178 | pub async fn get(&self, key: String) -> irpc::Result> { 179 | self.inner.rpc(Get { key }).await 180 | } 181 | 182 | pub async fn list(&self) -> irpc::Result> { 183 | self.inner.server_streaming(List, 10).await 184 | } 185 | 186 | pub async fn set(&self, key: String, value: String) -> irpc::Result<()> { 187 | let msg = Set { key, value }; 188 | self.inner.rpc(msg).await 189 | } 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /examples/derive.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::BTreeMap, 3 | net::{Ipv4Addr, SocketAddr, SocketAddrV4}, 4 | }; 5 | 6 | use anyhow::{Context, 
Result}; 7 | use irpc::{ 8 | channel::{mpsc, oneshot}, 9 | rpc::RemoteService, 10 | rpc_requests, 11 | util::{make_client_endpoint, make_server_endpoint}, 12 | Client, WithChannels, 13 | }; 14 | // Import the macro 15 | use n0_future::task::{self, AbortOnDropHandle}; 16 | use serde::{Deserialize, Serialize}; 17 | use tracing::info; 18 | 19 | #[derive(Debug, Serialize, Deserialize)] 20 | struct Get { 21 | key: String, 22 | } 23 | 24 | #[derive(Debug, Serialize, Deserialize)] 25 | struct List; 26 | 27 | #[derive(Debug, Serialize, Deserialize)] 28 | struct Set { 29 | key: String, 30 | value: String, 31 | } 32 | 33 | impl From<(String, String)> for Set { 34 | fn from((key, value): (String, String)) -> Self { 35 | Self { key, value } 36 | } 37 | } 38 | 39 | #[derive(Debug, Serialize, Deserialize)] 40 | struct SetMany; 41 | 42 | // Use the macro to generate both the StorageProtocol and StorageMessage enums 43 | // plus implement Channels for each type 44 | #[rpc_requests(message = StorageMessage)] 45 | #[derive(Serialize, Deserialize, Debug)] 46 | enum StorageProtocol { 47 | #[rpc(tx=oneshot::Sender>)] 48 | Get(Get), 49 | #[rpc(tx=oneshot::Sender<()>)] 50 | Set(Set), 51 | #[rpc(tx=oneshot::Sender, rx=mpsc::Receiver<(String, String)>)] 52 | SetMany(SetMany), 53 | #[rpc(tx=mpsc::Sender)] 54 | List(List), 55 | } 56 | 57 | struct StorageActor { 58 | recv: tokio::sync::mpsc::Receiver, 59 | state: BTreeMap, 60 | } 61 | 62 | impl StorageActor { 63 | pub fn spawn() -> StorageApi { 64 | let (tx, rx) = tokio::sync::mpsc::channel(1); 65 | let actor = Self { 66 | recv: rx, 67 | state: BTreeMap::new(), 68 | }; 69 | n0_future::task::spawn(actor.run()); 70 | StorageApi { 71 | inner: Client::local(tx), 72 | } 73 | } 74 | 75 | async fn run(mut self) { 76 | while let Some(msg) = self.recv.recv().await { 77 | self.handle(msg).await; 78 | } 79 | } 80 | 81 | async fn handle(&mut self, msg: StorageMessage) { 82 | match msg { 83 | StorageMessage::Get(get) => { 84 | info!("get {:?}", get); 85 | 
let WithChannels { tx, inner, .. } = get; 86 | tx.send(self.state.get(&inner.key).cloned()).await.ok(); 87 | } 88 | StorageMessage::Set(set) => { 89 | info!("set {:?}", set); 90 | let WithChannels { tx, inner, .. } = set; 91 | self.state.insert(inner.key, inner.value); 92 | tx.send(()).await.ok(); 93 | } 94 | StorageMessage::SetMany(set) => { 95 | info!("set-many {:?}", set); 96 | let WithChannels { mut rx, tx, .. } = set; 97 | let mut count = 0; 98 | while let Ok(Some((key, value))) = rx.recv().await { 99 | self.state.insert(key, value); 100 | count += 1; 101 | } 102 | tx.send(count).await.ok(); 103 | } 104 | StorageMessage::List(list) => { 105 | info!("list {:?}", list); 106 | let WithChannels { tx, .. } = list; 107 | for (key, value) in &self.state { 108 | if tx.send(format!("{key}={value}")).await.is_err() { 109 | break; 110 | } 111 | } 112 | } 113 | } 114 | } 115 | } 116 | 117 | struct StorageApi { 118 | inner: Client, 119 | } 120 | 121 | impl StorageApi { 122 | pub fn connect(endpoint: quinn::Endpoint, addr: SocketAddr) -> Result { 123 | Ok(StorageApi { 124 | inner: Client::quinn(endpoint, addr), 125 | }) 126 | } 127 | 128 | pub fn listen(&self, endpoint: quinn::Endpoint) -> Result> { 129 | let local = self 130 | .inner 131 | .as_local() 132 | .context("cannot listen on remote API")?; 133 | let join_handle = task::spawn(irpc::rpc::listen( 134 | endpoint, 135 | StorageProtocol::remote_handler(local), 136 | )); 137 | Ok(AbortOnDropHandle::new(join_handle)) 138 | } 139 | 140 | pub async fn get(&self, key: String) -> irpc::Result> { 141 | self.inner.rpc(Get { key }).await 142 | } 143 | 144 | pub async fn list(&self) -> irpc::Result> { 145 | self.inner.server_streaming(List, 16).await 146 | } 147 | 148 | pub async fn set(&self, key: String, value: String) -> irpc::Result<()> { 149 | self.inner.rpc(Set { key, value }).await 150 | } 151 | 152 | pub async fn set_many( 153 | &self, 154 | ) -> irpc::Result<(mpsc::Sender<(String, String)>, oneshot::Receiver)> { 155 | 
self.inner.client_streaming(SetMany, 4).await 156 | } 157 | } 158 | 159 | async fn client_demo(api: StorageApi) -> Result<()> { 160 | api.set("hello".to_string(), "world".to_string()).await?; 161 | let value = api.get("hello".to_string()).await?; 162 | println!("get: hello = {value:?}"); 163 | 164 | let (tx, rx) = api.set_many().await?; 165 | for i in 0..3 { 166 | tx.send((format!("key{i}"), format!("value{i}"))).await?; 167 | } 168 | drop(tx); 169 | let count = rx.await?; 170 | println!("set-many: {count} values set"); 171 | 172 | let mut list = api.list().await?; 173 | while let Some(value) = list.recv().await? { 174 | println!("list value = {value:?}"); 175 | } 176 | Ok(()) 177 | } 178 | 179 | async fn local() -> Result<()> { 180 | let api = StorageActor::spawn(); 181 | client_demo(api).await?; 182 | Ok(()) 183 | } 184 | 185 | async fn remote() -> Result<()> { 186 | let port = 10113; 187 | let addr: SocketAddr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into(); 188 | 189 | let (server_handle, cert) = { 190 | let (endpoint, cert) = make_server_endpoint(addr)?; 191 | let api = StorageActor::spawn(); 192 | let handle = api.listen(endpoint)?; 193 | (handle, cert) 194 | }; 195 | 196 | let endpoint = 197 | make_client_endpoint(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(), &[&cert])?; 198 | let api = StorageApi::connect(endpoint, addr)?; 199 | client_demo(api).await?; 200 | 201 | drop(server_handle); 202 | Ok(()) 203 | } 204 | 205 | #[tokio::main] 206 | async fn main() -> Result<()> { 207 | tracing_subscriber::fmt::init(); 208 | println!("Local use"); 209 | local().await?; 210 | println!("Remote use"); 211 | remote().await.unwrap(); 212 | Ok(()) 213 | } 214 | -------------------------------------------------------------------------------- /examples/storage.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::BTreeMap, 3 | net::{Ipv4Addr, SocketAddr, SocketAddrV4}, 4 | }; 5 | 6 | use anyhow::bail; 
7 | use irpc::{ 8 | channel::{mpsc, none::NoReceiver, oneshot}, 9 | rpc::{listen, RemoteService}, 10 | util::{make_client_endpoint, make_server_endpoint}, 11 | Channels, Client, Request, Service, WithChannels, 12 | }; 13 | use n0_future::task::{self, AbortOnDropHandle}; 14 | use serde::{Deserialize, Serialize}; 15 | use tracing::info; 16 | 17 | impl Service for StorageProtocol { 18 | type Message = StorageMessage; 19 | } 20 | 21 | #[derive(Debug, Serialize, Deserialize)] 22 | struct Get { 23 | key: String, 24 | } 25 | 26 | impl Channels for Get { 27 | type Rx = NoReceiver; 28 | type Tx = oneshot::Sender>; 29 | } 30 | 31 | #[derive(Debug, Serialize, Deserialize)] 32 | struct List; 33 | 34 | impl Channels for List { 35 | type Rx = NoReceiver; 36 | type Tx = mpsc::Sender; 37 | } 38 | 39 | #[derive(Debug, Serialize, Deserialize)] 40 | struct Set { 41 | key: String, 42 | value: String, 43 | } 44 | 45 | impl Channels for Set { 46 | type Rx = NoReceiver; 47 | type Tx = oneshot::Sender<()>; 48 | } 49 | 50 | #[derive(derive_more::From, Serialize, Deserialize, Debug)] 51 | enum StorageProtocol { 52 | Get(Get), 53 | Set(Set), 54 | List(List), 55 | } 56 | 57 | #[derive(derive_more::From)] 58 | enum StorageMessage { 59 | Get(WithChannels), 60 | Set(WithChannels), 61 | List(WithChannels), 62 | } 63 | 64 | impl RemoteService for StorageProtocol { 65 | fn with_remote_channels(self, rx: quinn::RecvStream, tx: quinn::SendStream) -> Self::Message { 66 | match self { 67 | StorageProtocol::Get(msg) => WithChannels::from((msg, tx, rx)).into(), 68 | StorageProtocol::Set(msg) => WithChannels::from((msg, tx, rx)).into(), 69 | StorageProtocol::List(msg) => WithChannels::from((msg, tx, rx)).into(), 70 | } 71 | } 72 | } 73 | 74 | struct StorageActor { 75 | recv: tokio::sync::mpsc::Receiver, 76 | state: BTreeMap, 77 | } 78 | 79 | impl StorageActor { 80 | pub fn local() -> StorageApi { 81 | let (tx, rx) = tokio::sync::mpsc::channel(1); 82 | let actor = Self { 83 | recv: rx, 84 | state: 
BTreeMap::new(), 85 | }; 86 | n0_future::task::spawn(actor.run()); 87 | StorageApi { 88 | inner: Client::local(tx), 89 | } 90 | } 91 | 92 | async fn run(mut self) { 93 | while let Some(msg) = self.recv.recv().await { 94 | self.handle(msg).await; 95 | } 96 | } 97 | 98 | async fn handle(&mut self, msg: StorageMessage) { 99 | match msg { 100 | StorageMessage::Get(get) => { 101 | info!("get {:?}", get); 102 | let WithChannels { tx, inner, .. } = get; 103 | tx.send(self.state.get(&inner.key).cloned()).await.ok(); 104 | } 105 | StorageMessage::Set(set) => { 106 | info!("set {:?}", set); 107 | let WithChannels { tx, inner, .. } = set; 108 | self.state.insert(inner.key, inner.value); 109 | tx.send(()).await.ok(); 110 | } 111 | StorageMessage::List(list) => { 112 | info!("list {:?}", list); 113 | let WithChannels { tx, .. } = list; 114 | for (key, value) in &self.state { 115 | if tx.send(format!("{key}={value}")).await.is_err() { 116 | break; 117 | } 118 | } 119 | } 120 | } 121 | } 122 | } 123 | struct StorageApi { 124 | inner: Client, 125 | } 126 | 127 | impl StorageApi { 128 | pub fn connect(endpoint: quinn::Endpoint, addr: SocketAddr) -> anyhow::Result { 129 | Ok(StorageApi { 130 | inner: Client::quinn(endpoint, addr), 131 | }) 132 | } 133 | 134 | pub fn listen(&self, endpoint: quinn::Endpoint) -> anyhow::Result> { 135 | let Some(local) = self.inner.as_local() else { 136 | bail!("cannot listen on a remote service"); 137 | }; 138 | let handler = StorageProtocol::remote_handler(local); 139 | Ok(AbortOnDropHandle::new(task::spawn(listen( 140 | endpoint, handler, 141 | )))) 142 | } 143 | 144 | pub async fn get(&self, key: String) -> anyhow::Result>> { 145 | let msg = Get { key }; 146 | match self.inner.request().await? 
{ 147 | Request::Local(request) => { 148 | let (tx, rx) = oneshot::channel(); 149 | request.send((msg, tx)).await?; 150 | Ok(rx) 151 | } 152 | Request::Remote(request) => { 153 | let (_tx, rx) = request.write(msg).await?; 154 | Ok(rx.into()) 155 | } 156 | } 157 | } 158 | 159 | pub async fn list(&self) -> anyhow::Result> { 160 | let msg = List; 161 | match self.inner.request().await? { 162 | Request::Local(request) => { 163 | let (tx, rx) = mpsc::channel(10); 164 | request.send((msg, tx)).await?; 165 | Ok(rx) 166 | } 167 | Request::Remote(request) => { 168 | let (_tx, rx) = request.write(msg).await?; 169 | Ok(rx.into()) 170 | } 171 | } 172 | } 173 | 174 | pub async fn set(&self, key: String, value: String) -> anyhow::Result> { 175 | let msg = Set { key, value }; 176 | match self.inner.request().await? { 177 | Request::Local(request) => { 178 | let (tx, rx) = oneshot::channel(); 179 | request.send((msg, tx)).await?; 180 | Ok(rx) 181 | } 182 | Request::Remote(request) => { 183 | let (_tx, rx) = request.write(msg).await?; 184 | Ok(rx.into()) 185 | } 186 | } 187 | } 188 | } 189 | 190 | async fn local() -> anyhow::Result<()> { 191 | let api = StorageActor::local(); 192 | api.set("hello".to_string(), "world".to_string()) 193 | .await? 194 | .await?; 195 | let value = api.get("hello".to_string()).await?.await?; 196 | let mut list = api.list().await?; 197 | while let Some(value) = list.recv().await? 
{ 198 | println!("list value = {value:?}"); 199 | } 200 | println!("value = {value:?}"); 201 | Ok(()) 202 | } 203 | 204 | async fn remote() -> anyhow::Result<()> { 205 | let port = 10113; 206 | let (server, cert) = 207 | make_server_endpoint(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into())?; 208 | let client = 209 | make_client_endpoint(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(), &[&cert])?; 210 | let store = StorageActor::local(); 211 | let handle = store.listen(server)?; 212 | let api = StorageApi::connect(client, SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into())?; 213 | api.set("hello".to_string(), "world".to_string()) 214 | .await? 215 | .await?; 216 | api.set("goodbye".to_string(), "world".to_string()) 217 | .await? 218 | .await?; 219 | let value = api.get("hello".to_string()).await?.await?; 220 | println!("value = {value:?}"); 221 | let mut list = api.list().await?; 222 | while let Some(value) = list.recv().await? { 223 | println!("list value = {value:?}"); 224 | } 225 | drop(handle); 226 | Ok(()) 227 | } 228 | 229 | #[tokio::main] 230 | async fn main() -> anyhow::Result<()> { 231 | tracing_subscriber::fmt::init(); 232 | println!("Local use"); 233 | local().await?; 234 | println!("Remote use"); 235 | remote().await?; 236 | Ok(()) 237 | } 238 | -------------------------------------------------------------------------------- /irpc-iroh/examples/auth.rs: -------------------------------------------------------------------------------- 1 | //! This example demonstrates a few things: 2 | //! * Using irpc with a cloneable server struct instead of with an actor loop 3 | //! * Manually implementing the connection loop 4 | //! 
* Authenticating peers 5 | 6 | use anyhow::Result; 7 | use iroh::{protocol::Router, Endpoint}; 8 | 9 | use self::storage::{StorageClient, StorageServer}; 10 | 11 | #[tokio::main] 12 | async fn main() -> Result<()> { 13 | tracing_subscriber::fmt::init(); 14 | println!("Remote use"); 15 | remote().await?; 16 | Ok(()) 17 | } 18 | 19 | async fn remote() -> Result<()> { 20 | let (server_router, server_addr) = { 21 | let endpoint = Endpoint::bind().await?; 22 | let server = StorageServer::new("secret".to_string()); 23 | let router = Router::builder(endpoint.clone()) 24 | .accept(StorageServer::ALPN, server.clone()) 25 | .spawn(); 26 | let addr = endpoint.addr(); 27 | (router, addr) 28 | }; 29 | 30 | // correct authentication 31 | let client_endpoint = Endpoint::builder().bind().await?; 32 | let api = StorageClient::connect(client_endpoint, server_addr.clone()); 33 | api.auth("secret").await?; 34 | api.set("hello".to_string(), "world".to_string()).await?; 35 | api.set("goodbye".to_string(), "world".to_string()).await?; 36 | let value = api.get("hello".to_string()).await?; 37 | println!("value = {value:?}"); 38 | let mut list = api.list().await?; 39 | while let Some(value) = list.recv().await? { 40 | println!("list value = {value:?}"); 41 | } 42 | 43 | // invalid authentication 44 | let client_endpoint = Endpoint::builder().bind().await?; 45 | let api = StorageClient::connect(client_endpoint, server_addr.clone()); 46 | assert!(api.auth("bad").await.is_err()); 47 | assert!(api.get("hello".to_string()).await.is_err()); 48 | 49 | // no authentication 50 | let client_endpoint = Endpoint::builder().bind().await?; 51 | let api = StorageClient::connect(client_endpoint, server_addr); 52 | assert!(api.get("hello".to_string()).await.is_err()); 53 | 54 | drop(server_router); 55 | Ok(()) 56 | } 57 | 58 | mod storage { 59 | //! Implementation of our storage service. 60 | //! 61 | //! The only `pub` items are [`StorageServer`] and [`StorageClient`]; everything else is private.
62 | 63 | use std::{ 64 | collections::BTreeMap, 65 | sync::{Arc, Mutex}, 66 | }; 67 | 68 | use anyhow::Result; 69 | use iroh::{ 70 | endpoint::Connection, 71 | protocol::{AcceptError, ProtocolHandler}, 72 | Endpoint, 73 | }; 74 | use irpc::{ 75 | channel::{mpsc, oneshot}, 76 | rpc_requests, Client, WithChannels, 77 | }; 78 | // irpc-iroh helpers; the `rpc_requests` macro itself is imported from `irpc` above 79 | use irpc_iroh::{read_request, IrohLazyRemoteConnection}; 80 | use serde::{Deserialize, Serialize}; 81 | use tracing::info; 82 | 83 | const ALPN: &[u8] = b"storage-api/0"; 84 | 85 | #[derive(Debug, Serialize, Deserialize)] 86 | struct Auth { 87 | token: String, 88 | } 89 | 90 | #[derive(Debug, Serialize, Deserialize)] 91 | struct Get { 92 | key: String, 93 | } 94 | 95 | #[derive(Debug, Serialize, Deserialize)] 96 | struct List; 97 | 98 | #[derive(Debug, Serialize, Deserialize)] 99 | struct Set { 100 | key: String, 101 | value: String, 102 | } 103 | 104 | #[derive(Debug, Serialize, Deserialize)] 105 | struct SetMany; 106 | 107 | // Use the macro to generate both the StorageProtocol and StorageMessage enums 108 | // plus implement Channels for each type 109 | #[rpc_requests(message = StorageMessage)] 110 | #[derive(Serialize, Deserialize, Debug)] 111 | enum StorageProtocol { 112 | #[rpc(tx=oneshot::Sender>)] 113 | Auth(Auth), 114 | #[rpc(tx=oneshot::Sender>)] 115 | Get(Get), 116 | #[rpc(tx=oneshot::Sender<()>)] 117 | Set(Set), 118 | #[rpc(tx=oneshot::Sender, rx=mpsc::Receiver<(String, String)>)] 119 | SetMany(SetMany), 120 | #[rpc(tx=mpsc::Sender)] 121 | List(List), 122 | } 123 | 124 | #[derive(Debug, Clone)] 125 | pub struct StorageServer { 126 | state: Arc>>, 127 | auth_token: String, 128 | } 129 | 130 | impl ProtocolHandler for StorageServer { 131 | async fn accept(&self, conn: Connection) -> Result<(), AcceptError> { 132 | let mut authed = false; 133 | while let Some(msg) = read_request::(&conn).await? 
{ 134 | match msg { 135 | StorageMessage::Auth(msg) => { 136 | let WithChannels { inner, tx, ..
} = msg; 137 | if authed { 138 | conn.close(1u32.into(), b"invalid message"); 139 | break; 140 | } else if inner.token != self.auth_token { 141 | conn.close(1u32.into(), b"permission denied"); 142 | break; 143 | } else { 144 | authed = true; 145 | tx.send(Ok(())).await.ok(); 146 | } 147 | } 148 | msg => { 149 | if !authed { 150 | conn.close(1u32.into(), b"permission denied"); 151 | break; 152 | } else { 153 | self.handle_authenticated(msg).await; 154 | } 155 | } 156 | } 157 | } 158 | conn.closed().await; 159 | Ok(()) 160 | } 161 | } 162 | 163 | impl StorageServer { 164 | pub const ALPN: &[u8] = ALPN; 165 | 166 | pub fn new(auth_token: String) -> Self { 167 | Self { 168 | state: Default::default(), 169 | auth_token, 170 | } 171 | } 172 | 173 | async fn handle_authenticated(&self, msg: StorageMessage) { 174 | match msg { 175 | StorageMessage::Auth(_) => unreachable!("handled in ProtocolHandler::accept"), 176 | StorageMessage::Get(get) => { 177 | info!("get {:?}", get); 178 | let WithChannels { tx, inner, .. } = get; 179 | let res = self.state.lock().unwrap().get(&inner.key).cloned(); 180 | tx.send(res).await.ok(); 181 | } 182 | StorageMessage::Set(set) => { 183 | info!("set {:?}", set); 184 | let WithChannels { tx, inner, .. } = set; 185 | self.state.lock().unwrap().insert(inner.key, inner.value); 186 | tx.send(()).await.ok(); 187 | } 188 | StorageMessage::SetMany(list) => { 189 | let WithChannels { tx, mut rx, .. } = list; 190 | let mut i = 0; 191 | while let Ok(Some((key, value))) = rx.recv().await { 192 | let mut state = self.state.lock().unwrap(); 193 | state.insert(key, value); 194 | i += 1; 195 | } 196 | tx.send(i).await.ok(); 197 | } 198 | StorageMessage::List(list) => { 199 | info!("list {:?}", list); 200 | let WithChannels { tx, .. } = list; 201 | let values = { 202 | let state = self.state.lock().unwrap(); 203 | // TODO: use async lock to not clone here. 
204 | let values: Vec<_> = state 205 | .iter() 206 | .map(|(key, value)| format!("{key}={value}")) 207 | .collect(); 208 | values 209 | }; 210 | for value in values { 211 | if tx.send(value).await.is_err() { 212 | break; 213 | } 214 | } 215 | } 216 | } 217 | } 218 | } 219 | 220 | pub struct StorageClient { 221 | inner: Client, 222 | } 223 | 224 | impl StorageClient { 225 | pub const ALPN: &[u8] = ALPN; 226 | 227 | pub fn connect(endpoint: Endpoint, addr: impl Into) -> StorageClient { 228 | let conn = IrohLazyRemoteConnection::new(endpoint, addr.into(), Self::ALPN.to_vec()); 229 | StorageClient { 230 | inner: Client::boxed(conn), 231 | } 232 | } 233 | 234 | pub async fn auth(&self, token: &str) -> Result<(), anyhow::Error> { 235 | self.inner 236 | .rpc(Auth { 237 | token: token.to_string(), 238 | }) 239 | .await? 240 | .map_err(|err| anyhow::anyhow!(err)) 241 | } 242 | 243 | pub async fn get(&self, key: String) -> Result, irpc::Error> { 244 | self.inner.rpc(Get { key }).await 245 | } 246 | 247 | pub async fn list(&self) -> Result, irpc::Error> { 248 | self.inner.server_streaming(List, 10).await 249 | } 250 | 251 | pub async fn set(&self, key: String, value: String) -> Result<(), irpc::Error> { 252 | let msg = Set { key, value }; 253 | self.inner.rpc(msg).await 254 | } 255 | } 256 | } 257 | -------------------------------------------------------------------------------- /tests/mpsc_channel.rs: -------------------------------------------------------------------------------- 1 | #![cfg(feature = "quinn_endpoint_setup")] 2 | 3 | use std::{ 4 | io::{self, ErrorKind}, 5 | time::Duration, 6 | }; 7 | 8 | use irpc::{ 9 | channel::{ 10 | mpsc::{self, Receiver, RecvError}, 11 | SendError, 12 | }, 13 | util::AsyncWriteVarintExt, 14 | }; 15 | use n0_error::e; 16 | use quinn::Endpoint; 17 | use testresult::TestResult; 18 | use tokio::time::timeout; 19 | 20 | mod common; 21 | use common::*; 22 | 23 | /// Checks that all clones of a `Sender` will get the closed signal as soon as 24 
| /// a send fails with an io error. 25 | #[tokio::test] 26 | async fn mpsc_sender_clone_closed_error() -> TestResult<()> { 27 | tracing_subscriber::fmt::try_init().ok(); 28 | let (server, client, server_addr) = create_connected_endpoints()?; 29 | // accept a single bidi stream on a single connection, then immediately stop it 30 | let server = tokio::spawn(async move { 31 | let conn = server.accept().await.unwrap().await?; 32 | let (_, mut recv) = conn.accept_bi().await?; 33 | recv.stop(1u8.into())?; 34 | TestResult::Ok(()) 35 | }); 36 | let conn = client.connect(server_addr, "localhost")?.await?; 37 | let (send, _) = conn.open_bi().await?; 38 | let send1 = mpsc::Sender::>::from(send); 39 | let send2 = send1.clone(); 40 | let send3 = send1.clone(); 41 | let second_client = tokio::spawn(async move { 42 | send2.closed().await; 43 | }); 44 | let third_client = tokio::spawn(async move { 45 | // this should fail with an io error, since the stream was stopped 46 | loop { 47 | match send3.send(vec![1, 2, 3]).await { 48 | Err(SendError::Io { source, .. }) if source.kind() == ErrorKind::BrokenPipe => { 49 | break 50 | } 51 | _ => {} 52 | }; 53 | } 54 | }); 55 | // send until we get an error because the remote side stopped the stream 56 | while send1.send(vec![1, 2, 3]).await.is_ok() {} 57 | match send1.send(vec![4, 5, 6]).await { 58 | Err(SendError::Io { source, .. }) if source.kind() == ErrorKind::BrokenPipe => {} 59 | e => panic!("Expected SendError::Io with kind BrokenPipe, got {e:?}"), 60 | }; 61 | // check that closed signal was received by the second sender 62 | second_client.await?; 63 | // check that the third sender will get the right kind of io error eventually 64 | third_client.await?; 65 | // server should finish without errors 66 | server.await??; 67 | Ok(()) 68 | } 69 | 70 | /// Checks that all clones of a `Sender` will get the closed signal as soon as 71 | /// a send future gets dropped before completing. 
72 | #[tokio::test] 73 | async fn mpsc_sender_clone_drop_error() -> TestResult<()> { 74 | let (server, client, server_addr) = create_connected_endpoints()?; 75 | // accept a single bidi stream on a single connection, then read indefinitely 76 | // until we get an error or the stream is finished 77 | let server = tokio::spawn(async move { 78 | let conn = server.accept().await.unwrap().await?; 79 | let (_, mut recv) = conn.accept_bi().await?; 80 | let mut buf = vec![0u8; 1024]; 81 | while let Ok(Some(_)) = recv.read(&mut buf).await {} 82 | TestResult::Ok(()) 83 | }); 84 | let conn = client.connect(server_addr, "localhost")?.await?; 85 | let (send, _) = conn.open_bi().await?; 86 | let send1 = mpsc::Sender::>::from(send); 87 | let send2 = send1.clone(); 88 | let send3 = send1.clone(); 89 | let second_client = tokio::spawn(async move { 90 | send2.closed().await; 91 | }); 92 | let third_client = tokio::spawn(async move { 93 | // this should fail with an io error, since the stream was stopped 94 | loop { 95 | match send3.send(vec![1, 2, 3]).await { 96 | Err(SendError::Io { source, .. }) if source.kind() == ErrorKind::BrokenPipe => { 97 | break 98 | } 99 | _ => {} 100 | }; 101 | } 102 | }); 103 | // send a lot of data with a tiny timeout, this will cause the send future to be dropped 104 | loop { 105 | let send_future = send1.send(vec![0u8; 1024 * 1024]); 106 | // not sure if there is a better way. I want to poll the future a few times so it has time to 107 | // start sending, but don't want to give it enough time to complete. 
108 | // I don't think now_or_never would work, since it wouldn't have time to start sending 109 | if timeout(Duration::from_micros(1), send_future) 110 | .await 111 | .is_err() 112 | { 113 | break; 114 | } 115 | } 116 | server.await??; 117 | second_client.await?; 118 | third_client.await?; 119 | Ok(()) 120 | } 121 | 122 | async fn vec_receiver(server: Endpoint) -> Result<(), RecvError> { 123 | let conn = server 124 | .accept() 125 | .await 126 | .unwrap() 127 | .await 128 | .map_err(|err| e!(RecvError::Io, err.into()))?; 129 | let (_, recv) = conn 130 | .accept_bi() 131 | .await 132 | .map_err(|err| e!(RecvError::Io, err.into()))?; 133 | let mut recv = Receiver::>::from(recv); 134 | while recv.recv().await?.is_some() {} 135 | Err(e!(RecvError::Io, io::ErrorKind::UnexpectedEof.into())) 136 | } 137 | 138 | /// Checks that the max message size is enforced on the sender side and that errors are propagated to the receiver side. 139 | #[tokio::test] 140 | async fn mpsc_max_message_size_send() -> TestResult<()> { 141 | let (server, client, server_addr) = create_connected_endpoints()?; 142 | let server = tokio::spawn(vec_receiver(server)); 143 | let conn = client.connect(server_addr, "localhost")?.await?; 144 | let (send, _) = conn.open_bi().await?; 145 | let send = mpsc::Sender::>::from(send); 146 | // this one should work! 147 | send.send(vec![0u8; 1024 * 1024]).await?; 148 | // this one should fail! 149 | let Err(cause) = send.send(vec![0u8; 1024 * 1024 * 32]).await else { 150 | panic!("client should have failed due to max message size"); 151 | }; 152 | assert!(matches!(cause, SendError::MaxMessageSizeExceeded { .. })); 153 | let Err(cause) = server.await? else { 154 | panic!("server should have failed due to max message size"); 155 | }; 156 | assert!( 157 | matches!(cause, mpsc::RecvError::Io { source, .. } if source.kind() == ErrorKind::ConnectionReset) 158 | ); 159 | Ok(()) 160 | } 161 | 162 | /// Checks that the max message size is enforced on receiver side. 
163 | #[tokio::test] 164 | async fn mpsc_max_message_size_recv() -> TestResult<()> { 165 | let (server, client, server_addr) = create_connected_endpoints()?; 166 | let server = tokio::spawn(vec_receiver(server)); 167 | let conn = client.connect(server_addr, "localhost")?.await?; 168 | let (mut send, _) = conn.open_bi().await?; 169 | // this one should work! 170 | send.write_length_prefixed(vec![0u8; 1024 * 1024]).await?; 171 | // this one should fail on receive! 172 | send.write_length_prefixed(vec![0u8; 1024 * 1024 * 32]) 173 | .await 174 | .ok(); 175 | let Err(cause) = server.await? else { 176 | panic!("server should have failed due to max message size"); 177 | }; 178 | assert!(matches!( 179 | cause, 180 | mpsc::RecvError::MaxMessageSizeExceeded { .. } 181 | )); 182 | Ok(()) 183 | } 184 | 185 | async fn noser_receiver(server: Endpoint) -> Result<(), mpsc::RecvError> { 186 | let conn = server 187 | .accept() 188 | .await 189 | .unwrap() 190 | .await 191 | .map_err(|err| e!(mpsc::RecvError::Io, err.into()))?; 192 | let (_, recv) = conn 193 | .accept_bi() 194 | .await 195 | .map_err(|err| e!(mpsc::RecvError::Io, err.into()))?; 196 | let mut recv = mpsc::Receiver::::from(recv); 197 | while recv.recv().await?.is_some() {} 198 | Err(e!(mpsc::RecvError::Io, io::ErrorKind::UnexpectedEof.into())) 199 | } 200 | 201 | /// Checks that a serialization error is caught and propagated to the receiver. 202 | #[tokio::test] 203 | async fn mpsc_serialize_error_send() -> TestResult<()> { 204 | let (server, client, server_addr) = create_connected_endpoints()?; 205 | let server = tokio::spawn(noser_receiver(server)); 206 | let conn = client.connect(server_addr, "localhost")?.await?; 207 | let (send, _) = conn.open_bi().await?; 208 | let send = mpsc::Sender::::from(send); 209 | // this one should work! 210 | send.send(NoSer(0)).await?; 211 | // this one should fail! 
212 | let Err(cause) = send.send(NoSer(1)).await else { 213 | panic!("client should have failed due to serialization error"); 214 | }; 215 | assert!( 216 | matches!(cause, SendError::Io { source, .. } if source.kind() == ErrorKind::InvalidData) 217 | ); 218 | let Err(cause) = server.await? else { 219 | panic!("server should have failed due to serialization error"); 220 | }; 221 | assert!( 222 | matches!(cause, mpsc::RecvError::Io { source, .. } if source.kind() == ErrorKind::ConnectionReset) 223 | ); 224 | Ok(()) 225 | } 226 | 227 | #[tokio::test] 228 | async fn mpsc_serialize_error_recv() -> TestResult<()> { 229 | let (server, client, server_addr) = create_connected_endpoints()?; 230 | let server = tokio::spawn(noser_receiver(server)); 231 | let conn = client.connect(server_addr, "localhost")?.await?; 232 | let (mut send, _) = conn.open_bi().await?; 233 | // this one should work! 234 | send.write_length_prefixed(0u64).await?; 235 | // this one should fail on receive! 236 | send.write_length_prefixed(1u64).await.ok(); 237 | let Err(cause) = server.await? else { 238 | panic!("server should have failed due to serialization error"); 239 | }; 240 | assert!( 241 | matches!(cause, mpsc::RecvError::Io { source, .. 
} if source.kind() == ErrorKind::InvalidData) 242 | ); 243 | Ok(()) 244 | } 245 | -------------------------------------------------------------------------------- /irpc-iroh/examples/0rtt.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | env, 3 | str::FromStr, 4 | time::{Duration, Instant}, 5 | }; 6 | 7 | use anyhow::{Context, Result}; 8 | use clap::Parser; 9 | use iroh::{protocol::Router, Endpoint, EndpointAddr, EndpointId, SecretKey}; 10 | use ping::EchoApi; 11 | 12 | #[tokio::main] 13 | async fn main() -> Result<()> { 14 | tracing_subscriber::fmt::init(); 15 | let args = cli::Args::parse(); 16 | match args { 17 | cli::Args::Listen { no_0rtt } => { 18 | let (server_router, server_addr) = { 19 | let secret_key = get_or_generate_secret_key()?; 20 | let endpoint = Endpoint::builder().secret_key(secret_key).bind().await?; 21 | endpoint.online().await; 22 | let addr = endpoint.addr(); 23 | let api = EchoApi::spawn(); 24 | let router = Router::builder(endpoint.clone()); 25 | let router = if !no_0rtt { 26 | router.accept(EchoApi::ALPN, api.expose_0rtt()?) 27 | } else { 28 | router.accept(EchoApi::ALPN, api.expose()?) 29 | }; 30 | let router = router.spawn(); 31 | (router, addr) 32 | }; 33 | println!("EndpointId: {}", server_addr.id); 34 | println!("Accepting 0rtt connections: {}", !no_0rtt); 35 | let ticket = server_addr.id.to_string(); 36 | println!("Connect using:\n\ncargo run --example 0rtt connect {ticket}\n"); 37 | println!("Control-C to stop"); 38 | tokio::signal::ctrl_c() 39 | .await 40 | .expect("failed to listen for ctrl_c"); 41 | server_router.shutdown().await?; 42 | } 43 | cli::Args::Connect { 44 | ticket, 45 | n, 46 | delay_ms, 47 | no_0rtt, 48 | wait_for_ticket, 49 | } => { 50 | if !no_0rtt && !wait_for_ticket { 51 | eprintln!("0-RTT is enabled but wait_for_ticket is not set. 
After 2 requests with 0rtt the 0rtt resumption tickets will be consumed and a connection will be done without 0rtt."); 52 | } 53 | let n = n 54 | .iter() 55 | .filter_map(|x| u64::try_from(*x).ok()) 56 | .next() 57 | .unwrap_or(u64::MAX); 58 | let delay = std::time::Duration::from_millis(delay_ms); 59 | let endpoint = Endpoint::builder().bind().await?; 60 | let addr: EndpointAddr = ticket.into(); 61 | for i in 0..n { 62 | if let Err(e) = ping_one(no_0rtt, &endpoint, &addr, i, wait_for_ticket).await { 63 | eprintln!("Error pinging {}: {e}", addr.id); 64 | } 65 | tokio::time::sleep(delay).await; 66 | } 67 | } 68 | } 69 | Ok(()) 70 | } 71 | 72 | async fn ping_one_0rtt( 73 | api: EchoApi, 74 | endpoint: &Endpoint, 75 | endpoint_id: EndpointId, 76 | wait_for_ticket: bool, 77 | i: u64, 78 | t0: Instant, 79 | ) -> Result<()> { 80 | let msg = i.to_be_bytes(); 81 | let data = api.echo_0rtt(msg.to_vec()).await?; 82 | let latency = endpoint.latency(endpoint_id); 83 | if wait_for_ticket { 84 | tokio::spawn(async move { 85 | let latency = latency.unwrap_or(Duration::from_millis(500)); 86 | tokio::time::sleep(latency * 2).await; 87 | drop(api); 88 | }); 89 | } else { 90 | drop(api); 91 | } 92 | let elapsed = t0.elapsed(); 93 | assert!(data == msg); 94 | println!( 95 | "latency: {}", 96 | latency 97 | .map(|x| format!("{}ms", x.as_micros() as f64 / 1000.0)) 98 | .unwrap_or("unknown".into()) 99 | ); 100 | println!("ping: {}ms\n", elapsed.as_micros() as f64 / 1000.0); 101 | Ok(()) 102 | } 103 | 104 | async fn ping_one_no_0rtt( 105 | api: EchoApi, 106 | endpoint: &Endpoint, 107 | endpoint_id: EndpointId, 108 | i: u64, 109 | t0: Instant, 110 | ) -> Result<()> { 111 | let msg = i.to_be_bytes(); 112 | let data = api.echo(msg.to_vec()).await?; 113 | let latency = endpoint.latency(endpoint_id); 114 | drop(api); 115 | let elapsed = t0.elapsed(); 116 | assert!(data == msg); 117 | println!( 118 | "latency: {}", 119 | latency 120 | .map(|x| format!("{}ms", x.as_micros() as f64 / 1000.0)) 121 
| .unwrap_or("unknown".into()) 122 | ); 123 | println!("ping: {}ms\n", elapsed.as_micros() as f64 / 1000.0); 124 | Ok(()) 125 | } 126 | 127 | async fn ping_one( 128 | no_0rtt: bool, 129 | endpoint: &Endpoint, 130 | addr: &EndpointAddr, 131 | i: u64, 132 | wait_for_ticket: bool, 133 | ) -> Result<()> { 134 | let endpoint_id = addr.id; 135 | let t0 = Instant::now(); 136 | if !no_0rtt { 137 | let api = EchoApi::connect_0rtt(endpoint.clone(), addr.clone()).await?; 138 | ping_one_0rtt(api, endpoint, endpoint_id, wait_for_ticket, i, t0).await?; 139 | } else { 140 | let api = EchoApi::connect(endpoint.clone(), addr.clone()).await?; 141 | ping_one_no_0rtt(api, endpoint, endpoint_id, i, t0).await?; 142 | } 143 | Ok(()) 144 | } 145 | 146 | /// Gets a secret key from the IROH_SECRET environment variable or generates a new random one. 147 | /// If the environment variable is set, it must be a valid string representation of a secret key. 148 | pub fn get_or_generate_secret_key() -> Result { 149 | if let Ok(secret) = env::var("IROH_SECRET") { 150 | // Parse the secret key from string 151 | SecretKey::from_str(&secret).context("Invalid secret key format") 152 | } else { 153 | // Generate a new random key 154 | let secret_key = SecretKey::generate(&mut rand::rng()); 155 | println!( 156 | "Generated new secret key: {}", 157 | hex::encode(secret_key.to_bytes()) 158 | ); 159 | println!("To reuse this key, set the IROH_SECRET environment variable to this value"); 160 | Ok(secret_key) 161 | } 162 | } 163 | 164 | mod cli { 165 | use clap::Parser; 166 | use iroh::EndpointId; 167 | 168 | #[derive(Debug, Parser)] 169 | pub enum Args { 170 | Listen { 171 | #[clap(long)] 172 | no_0rtt: bool, 173 | }, 174 | Connect { 175 | ticket: EndpointId, 176 | #[clap(short)] 177 | n: Option, 178 | #[clap(long)] 179 | no_0rtt: bool, 180 | #[clap(long, default_value = "1000")] 181 | delay_ms: u64, 182 | #[clap(long, default_value = "false")] 183 | wait_for_ticket: bool, 184 | }, 185 | } 186 | } 187 | 188 | 
mod ping { 189 | use anyhow::{Context, Result}; 190 | use iroh::Endpoint; 191 | use irpc::{channel::oneshot, rpc::RemoteService, rpc_requests, Client, WithChannels}; 192 | use irpc_iroh::{ 193 | Iroh0RttProtocol, IrohProtocol, IrohRemoteConnection, IrohZrttRemoteConnection, 194 | }; 195 | use serde::{Deserialize, Serialize}; 196 | use tracing::info; 197 | 198 | #[rpc_requests(message = EchoMessage)] 199 | #[derive(Serialize, Deserialize, Debug)] 200 | pub enum EchoProtocol { 201 | #[rpc(tx=oneshot::Sender>)] 202 | #[wrap(Echo)] 203 | Echo { data: Vec }, 204 | } 205 | 206 | pub struct EchoApi { 207 | inner: Client, 208 | } 209 | 210 | impl EchoApi { 211 | pub const ALPN: &[u8] = b"echo"; 212 | 213 | pub async fn echo(&self, data: Vec) -> irpc::Result> { 214 | self.inner.rpc(Echo { data }).await 215 | } 216 | 217 | pub async fn echo_0rtt(&self, data: Vec) -> irpc::Result> { 218 | self.inner.rpc_0rtt(Echo { data }).await 219 | } 220 | 221 | pub fn expose_0rtt(self) -> Result> { 222 | let local = self 223 | .inner 224 | .as_local() 225 | .context("can not listen on remote service")?; 226 | Ok(Iroh0RttProtocol::new(EchoProtocol::remote_handler(local))) 227 | } 228 | 229 | pub fn expose(self) -> Result> { 230 | let local = self 231 | .inner 232 | .as_local() 233 | .context("can not listen on remote service")?; 234 | Ok(IrohProtocol::new(EchoProtocol::remote_handler(local))) 235 | } 236 | 237 | pub async fn connect( 238 | endpoint: Endpoint, 239 | addr: impl Into, 240 | ) -> Result { 241 | let conn = endpoint 242 | .connect(addr, Self::ALPN) 243 | .await 244 | .context("failed to connect to remote service")?; 245 | Ok(EchoApi { 246 | inner: Client::boxed(IrohRemoteConnection::new(conn)), 247 | }) 248 | } 249 | 250 | pub async fn connect_0rtt( 251 | endpoint: Endpoint, 252 | addr: impl Into, 253 | ) -> Result { 254 | let connecting = endpoint 255 | .connect_with_opts(addr, Self::ALPN, Default::default()) 256 | .await 257 | .context("failed to connect to remote service")?; 
258 | match connecting.into_0rtt() { 259 | Ok(conn) => { 260 | info!("0-RTT possible from our side"); 261 | Ok(EchoApi { 262 | inner: Client::boxed(IrohZrttRemoteConnection::new(conn)), 263 | }) 264 | } 265 | Err(connecting) => { 266 | info!("0-RTT not possible from our side"); 267 | let conn = connecting.await?; 268 | Ok(EchoApi { 269 | inner: Client::boxed(IrohRemoteConnection::new(conn)), 270 | }) 271 | } 272 | } 273 | } 274 | 275 | pub fn spawn() -> Self { 276 | EchoActor::spawn() 277 | } 278 | } 279 | 280 | struct EchoActor { 281 | recv: tokio::sync::mpsc::Receiver, 282 | } 283 | 284 | impl EchoActor { 285 | pub fn spawn() -> EchoApi { 286 | let (tx, rx) = tokio::sync::mpsc::channel(1); 287 | let actor = Self { recv: rx }; 288 | n0_future::task::spawn(actor.run()); 289 | EchoApi { 290 | inner: Client::local(tx), 291 | } 292 | } 293 | 294 | async fn run(mut self) { 295 | while let Some(msg) = self.recv.recv().await { 296 | self.handle(msg).await; 297 | } 298 | } 299 | 300 | async fn handle(&mut self, msg: EchoMessage) { 301 | match msg { 302 | EchoMessage::Echo(msg) => { 303 | info!("{:?}", msg); 304 | let WithChannels { tx, inner, .. 
} = msg; 305 | tx.send(inner.data).await.ok(); 306 | } 307 | } 308 | } 309 | } 310 | } 311 | -------------------------------------------------------------------------------- /examples/compute.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | io::{self, Write}, 3 | net::{Ipv4Addr, SocketAddr, SocketAddrV4}, 4 | }; 5 | 6 | use anyhow::bail; 7 | use futures_buffered::BufferedStreamExt; 8 | use irpc::{ 9 | channel::{mpsc, oneshot}, 10 | rpc::{listen, RemoteService}, 11 | rpc_requests, 12 | util::{make_client_endpoint, make_server_endpoint}, 13 | Client, Request, WithChannels, 14 | }; 15 | use n0_future::{ 16 | stream::StreamExt, 17 | task::{self, AbortOnDropHandle}, 18 | }; 19 | use serde::{Deserialize, Serialize}; 20 | use thousands::Separable; 21 | use tracing::trace; 22 | 23 | // Define the protocol and message enums using the macro 24 | #[rpc_requests(message = ComputeMessage)] 25 | #[derive(Serialize, Deserialize, Debug)] 26 | enum ComputeProtocol { 27 | #[rpc(tx=oneshot::Sender)] 28 | Sqr(Sqr), 29 | #[rpc(rx=mpsc::Receiver, tx=oneshot::Sender)] 30 | Sum(Sum), 31 | #[rpc(tx=mpsc::Sender)] 32 | Fibonacci(Fibonacci), 33 | #[rpc(rx=mpsc::Receiver, tx=mpsc::Sender)] 34 | Multiply(Multiply), 35 | } 36 | 37 | // Define ComputeProtocol sub-messages 38 | #[derive(Debug, Serialize, Deserialize)] 39 | struct Sqr { 40 | num: u64, 41 | } 42 | 43 | #[derive(Debug, Serialize, Deserialize)] 44 | struct Sum; 45 | 46 | #[derive(Debug, Serialize, Deserialize)] 47 | struct Fibonacci { 48 | max: u64, 49 | } 50 | 51 | #[derive(Debug, Serialize, Deserialize)] 52 | struct Multiply { 53 | initial: u64, 54 | } 55 | 56 | // The actor that processes requests 57 | struct ComputeActor { 58 | recv: irpc::channel::mpsc::Receiver, 59 | } 60 | 61 | impl ComputeActor { 62 | pub fn local() -> ComputeApi { 63 | let (tx, rx) = irpc::channel::mpsc::channel(128); 64 | let actor = Self { recv: rx }; 65 | n0_future::task::spawn(actor.run()); 66 | 
ComputeApi { 67 | inner: Client::local(tx), 68 | } 69 | } 70 | 71 | async fn run(mut self) { 72 | while let Ok(Some(msg)) = self.recv.recv().await { 73 | n0_future::task::spawn(async move { 74 | if let Err(cause) = Self::handle(msg).await { 75 | eprintln!("Error: {cause}"); 76 | } 77 | }); 78 | } 79 | } 80 | 81 | async fn handle(msg: ComputeMessage) -> io::Result<()> { 82 | match msg { 83 | ComputeMessage::Sqr(sqr) => { 84 | trace!("sqr {:?}", sqr); 85 | let WithChannels { 86 | tx, inner, span, .. 87 | } = sqr; 88 | let _entered = span.enter(); 89 | let result = (inner.num as u128) * (inner.num as u128); 90 | tx.send(result).await?; 91 | } 92 | ComputeMessage::Sum(sum) => { 93 | trace!("sum {:?}", sum); 94 | let WithChannels { rx, tx, span, .. } = sum; 95 | let _entered = span.enter(); 96 | let mut receiver = rx; 97 | let mut total = 0; 98 | while let Some(num) = receiver.recv().await? { 99 | total += num; 100 | } 101 | tx.send(total).await?; 102 | } 103 | ComputeMessage::Fibonacci(fib) => { 104 | trace!("fibonacci {:?}", fib); 105 | let WithChannels { 106 | tx, inner, span, .. 107 | } = fib; 108 | let _entered = span.enter(); 109 | let sender = tx; 110 | let mut a = 0u64; 111 | let mut b = 1u64; 112 | while a <= inner.max { 113 | sender.send(a).await?; 114 | let next = a + b; 115 | a = b; 116 | b = next; 117 | } 118 | } 119 | ComputeMessage::Multiply(mult) => { 120 | trace!("multiply {:?}", mult); 121 | let WithChannels { 122 | rx, 123 | tx, 124 | inner, 125 | span, 126 | .. 127 | } = mult; 128 | let _entered = span.enter(); 129 | let mut receiver = rx; 130 | let sender = tx; 131 | let multiplier = inner.initial; 132 | while let Some(num) = receiver.recv().await? 
{ 133 | sender.send(multiplier * num).await?; 134 | } 135 | } 136 | } 137 | Ok(()) 138 | } 139 | } 140 | // The API for interacting with the ComputeService 141 | #[derive(Clone)] 142 | struct ComputeApi { 143 | inner: Client, 144 | } 145 | 146 | impl ComputeApi { 147 | pub fn connect(endpoint: quinn::Endpoint, addr: SocketAddr) -> anyhow::Result { 148 | Ok(ComputeApi { 149 | inner: Client::quinn(endpoint, addr), 150 | }) 151 | } 152 | 153 | pub fn listen(&self, endpoint: quinn::Endpoint) -> anyhow::Result> { 154 | let Some(local) = self.inner.as_local() else { 155 | bail!("cannot listen on a remote service"); 156 | }; 157 | let handler = ComputeProtocol::remote_handler(local); 158 | Ok(AbortOnDropHandle::new(task::spawn(listen( 159 | endpoint, handler, 160 | )))) 161 | } 162 | 163 | pub async fn sqr(&self, num: u64) -> anyhow::Result> { 164 | let msg = Sqr { num }; 165 | match self.inner.request().await? { 166 | Request::Local(request) => { 167 | let (tx, rx) = oneshot::channel(); 168 | request.send((msg, tx)).await?; 169 | Ok(rx) 170 | } 171 | Request::Remote(request) => { 172 | let (_tx, rx) = request.write(msg).await?; 173 | Ok(rx.into()) 174 | } 175 | } 176 | } 177 | 178 | pub async fn sum(&self) -> anyhow::Result<(mpsc::Sender, oneshot::Receiver)> { 179 | let msg = Sum; 180 | match self.inner.request().await? { 181 | Request::Local(request) => { 182 | let (num_tx, num_rx) = mpsc::channel(10); 183 | let (sum_tx, sum_rx) = oneshot::channel(); 184 | request.send((msg, sum_tx, num_rx)).await?; 185 | Ok((num_tx, sum_rx)) 186 | } 187 | Request::Remote(request) => { 188 | let (tx, rx) = request.write(msg).await?; 189 | Ok((tx.into(), rx.into())) 190 | } 191 | } 192 | } 193 | 194 | pub async fn fibonacci(&self, max: u64) -> anyhow::Result> { 195 | let msg = Fibonacci { max }; 196 | match self.inner.request().await? 
{ 197 | Request::Local(request) => { 198 | let (tx, rx) = mpsc::channel(128); 199 | request.send((msg, tx)).await?; 200 | Ok(rx) 201 | } 202 | Request::Remote(request) => { 203 | let (_tx, rx) = request.write(msg).await?; 204 | Ok(rx.into()) 205 | } 206 | } 207 | } 208 | 209 | pub async fn multiply( 210 | &self, 211 | initial: u64, 212 | ) -> anyhow::Result<(mpsc::Sender, mpsc::Receiver)> { 213 | let msg = Multiply { initial }; 214 | match self.inner.request().await? { 215 | Request::Local(request) => { 216 | let (in_tx, in_rx) = mpsc::channel(128); 217 | let (out_tx, out_rx) = mpsc::channel(128); 218 | request.send((msg, out_tx, in_rx)).await?; 219 | Ok((in_tx, out_rx)) 220 | } 221 | Request::Remote(request) => { 222 | let (tx, rx) = request.write(msg).await?; 223 | Ok((tx.into(), rx.into())) 224 | } 225 | } 226 | } 227 | } 228 | 229 | // Local usage example 230 | async fn local() -> anyhow::Result<()> { 231 | let api = ComputeActor::local(); 232 | 233 | // Test Sqr 234 | let rx = api.sqr(5).await?; 235 | println!("Local: 5^2 = {}", rx.await?); 236 | 237 | // Test Sum 238 | let (tx, rx) = api.sum().await?; 239 | tx.send(1).await?; 240 | tx.send(2).await?; 241 | tx.send(3).await?; 242 | drop(tx); 243 | println!("Local: sum of [1, 2, 3] = {}", rx.await?); 244 | 245 | // Test Fibonacci 246 | let mut rx = api.fibonacci(10).await?; 247 | print!("Local: Fibonacci up to 10 = "); 248 | while let Some(num) = rx.recv().await? { 249 | print!("{num} "); 250 | } 251 | println!(); 252 | 253 | // Test Multiply 254 | let (in_tx, mut out_rx) = api.multiply(3).await?; 255 | in_tx.send(2).await?; 256 | in_tx.send(4).await?; 257 | in_tx.send(6).await?; 258 | drop(in_tx); 259 | print!("Local: 3 * [2, 4, 6] = "); 260 | while let Some(num) = out_rx.recv().await? 
{ 261 | print!("{num} "); 262 | } 263 | println!(); 264 | 265 | Ok(()) 266 | } 267 | 268 | fn remote_api() -> anyhow::Result<(ComputeApi, AbortOnDropHandle<()>)> { 269 | let port = 10114; 270 | let (server, cert) = 271 | make_server_endpoint(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into())?; 272 | let client = 273 | make_client_endpoint(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(), &[&cert])?; 274 | let compute = ComputeActor::local(); 275 | let handle = compute.listen(server)?; 276 | let api = ComputeApi::connect(client, SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into())?; 277 | Ok((api, handle)) 278 | } 279 | 280 | // Remote usage example 281 | async fn remote() -> anyhow::Result<()> { 282 | let (api, handle) = remote_api()?; 283 | 284 | // Test Sqr 285 | let rx = api.sqr(4).await?; 286 | println!("Remote: 4^2 = {}", rx.await?); 287 | 288 | // Test Sum 289 | let (tx, rx) = api.sum().await?; 290 | tx.send(4).await?; 291 | tx.send(5).await?; 292 | tx.send(6).await?; 293 | drop(tx); 294 | println!("Remote: sum of [4, 5, 6] = {}", rx.await?); 295 | 296 | // Test Fibonacci 297 | let mut rx = api.fibonacci(20).await?; 298 | print!("Remote: Fibonacci up to 20 = "); 299 | while let Some(num) = rx.recv().await? { 300 | print!("{num} "); 301 | } 302 | println!(); 303 | 304 | // Test Multiply 305 | let (in_tx, mut out_rx) = api.multiply(5).await?; 306 | in_tx.send(1).await?; 307 | in_tx.send(2).await?; 308 | in_tx.send(3).await?; 309 | drop(in_tx); 310 | print!("Remote: 5 * [1, 2, 3] = "); 311 | while let Some(num) = out_rx.recv().await? 
{ 312 | print!("{num} "); 313 | } 314 | println!(); 315 | 316 | drop(handle); 317 | Ok(()) 318 | } 319 | 320 | // Benchmark function using the new ComputeApi 321 | async fn bench(api: ComputeApi, n: u64) -> anyhow::Result<()> { 322 | // Individual RPCs (sequential) 323 | { 324 | let mut sum = 0; 325 | let t0 = std::time::Instant::now(); 326 | for i in 0..n { 327 | sum += api.sqr(i).await?.await?; 328 | if i % 10000 == 0 { 329 | print!("."); 330 | io::stdout().flush()?; 331 | } 332 | } 333 | let rps = ((n as f64) / t0.elapsed().as_secs_f64()).round() as u64; 334 | assert_eq!(sum, sum_of_squares(n)); 335 | clear_line()?; 336 | println!("RPC seq {} rps", rps.separate_with_underscores()); 337 | } 338 | 339 | // Parallel RPCs 340 | { 341 | let t0 = std::time::Instant::now(); 342 | let api = api.clone(); 343 | let reqs = n0_future::stream::iter((0..n).map(move |i| { 344 | let api = api.clone(); 345 | async move { anyhow::Ok(api.sqr(i).await?.await?) } 346 | })); 347 | let resp: Vec<_> = reqs.buffered_unordered(32).try_collect().await?; 348 | let sum = resp.into_iter().sum::(); 349 | let rps = ((n as f64) / t0.elapsed().as_secs_f64()).round() as u64; 350 | assert_eq!(sum, sum_of_squares(n)); 351 | clear_line()?; 352 | println!("RPC par {} rps", rps.separate_with_underscores()); 353 | } 354 | 355 | // Sequential streaming (using Multiply instead of MultiplyUpdate) 356 | { 357 | let t0 = std::time::Instant::now(); 358 | let (send, mut recv) = api.multiply(2).await?; 359 | let handle = tokio::task::spawn(async move { 360 | for i in 0..n { 361 | send.send(i).await?; 362 | } 363 | Ok::<(), io::Error>(()) 364 | }); 365 | let mut sum = 0; 366 | let mut i = 0; 367 | while let Some(res) = recv.recv().await? 
{ 368 | sum += res; 369 | if i % 10000 == 0 { 370 | print!("."); 371 | io::stdout().flush()?; 372 | } 373 | i += 1; 374 | } 375 | let rps = ((n as f64) / t0.elapsed().as_secs_f64()).round() as u64; 376 | assert_eq!(sum, (0..n).map(|x| x * 2).sum::()); 377 | clear_line()?; 378 | println!("bidi seq {} rps", rps.separate_with_underscores()); 379 | handle.await??; 380 | } 381 | 382 | Ok(()) 383 | } 384 | 385 | // Helper function to compute the sum of squares 386 | fn sum_of_squares(n: u64) -> u128 { 387 | (0..n).map(|x| (x * x) as u128).sum() 388 | } 389 | 390 | // Helper function to clear the current line 391 | fn clear_line() -> io::Result<()> { 392 | io::stdout().write_all(b"\r\x1b[K")?; 393 | io::stdout().flush()?; 394 | Ok(()) 395 | } 396 | 397 | // Simple benchmark sending oneshot senders via an mpsc channel 398 | pub async fn reference_bench(n: u64) -> anyhow::Result<()> { 399 | // Create an mpsc channel to send oneshot senders 400 | let (tx, mut rx) = tokio::sync::mpsc::channel::>(32); 401 | 402 | // Spawn a task to respond to all oneshot senders 403 | tokio::spawn(async move { 404 | while let Some(sender) = rx.recv().await { 405 | // Immediately send a fixed response (42) back through the oneshot sender 406 | sender.send(42).ok(); 407 | } 408 | Ok::<(), io::Error>(()) 409 | }); 410 | 411 | // Sequential oneshot sends 412 | { 413 | let mut sum = 0; 414 | let t0 = std::time::Instant::now(); 415 | for i in 0..n { 416 | let (send, recv) = tokio::sync::oneshot::channel(); 417 | tx.send(send).await?; 418 | sum += recv.await?; 419 | if i % 10000 == 0 { 420 | print!("."); 421 | io::stdout().flush()?; 422 | } 423 | } 424 | let rps = ((n as f64) / t0.elapsed().as_secs_f64()).round() as u64; 425 | assert_eq!(sum, 42 * n); // Each response is 42 426 | clear_line()?; 427 | println!("Reference seq {} rps", rps.separate_with_underscores()); 428 | } 429 | 430 | // Parallel oneshot sends 431 | { 432 | let t0 = std::time::Instant::now(); 433 | let reqs = 
n0_future::stream::iter((0..n).map(|_| async { 434 | let (send, recv) = tokio::sync::oneshot::channel(); 435 | tx.send(send).await?; 436 | anyhow::Ok(recv.await?) 437 | })); 438 | let resp: Vec<_> = reqs.buffered_unordered(32).try_collect().await?; 439 | let sum = resp.into_iter().sum::(); 440 | let rps = ((n as f64) / t0.elapsed().as_secs_f64()).round() as u64; 441 | assert_eq!(sum, 42 * n); // Each response is 42 442 | clear_line()?; 443 | println!("Reference par {} rps", rps.separate_with_underscores()); 444 | } 445 | 446 | Ok(()) 447 | } 448 | 449 | #[tokio::main] 450 | async fn main() -> anyhow::Result<()> { 451 | tracing_subscriber::fmt::init(); 452 | println!("Local use"); 453 | local().await?; 454 | println!("Remote use"); 455 | remote().await?; 456 | 457 | println!("Local bench"); 458 | let api = ComputeActor::local(); 459 | bench(api, 100000).await?; 460 | 461 | let (api, handle) = remote_api()?; 462 | println!("Remote bench"); 463 | bench(api, 100000).await?; 464 | drop(handle); 465 | 466 | println!("Reference bench"); 467 | reference_bench(100000).await?; 468 | Ok(()) 469 | } 470 | -------------------------------------------------------------------------------- /irpc-iroh/src/lib.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | fmt, 3 | future::Future, 4 | io, 5 | sync::{atomic::AtomicU64, Arc}, 6 | }; 7 | 8 | use iroh::{ 9 | endpoint::{ 10 | Accepting, ConnectingError, Connection, ConnectionError, IncomingZeroRttConnection, 11 | OutgoingZeroRttConnection, RecvStream, RemoteEndpointIdError, SendStream, VarInt, 12 | ZeroRttStatus, 13 | }, 14 | protocol::{AcceptError, ProtocolHandler}, 15 | EndpointId, 16 | }; 17 | use irpc::{ 18 | channel::oneshot, 19 | rpc::{ 20 | Handler, RemoteConnection, RemoteService, ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED, 21 | MAX_MESSAGE_SIZE, 22 | }, 23 | util::AsyncReadVarintExt, 24 | LocalSender, RequestError, 25 | }; 26 | use n0_error::{e, Result}; 27 | use 
n0_future::{future::Boxed as BoxFuture, TryFutureExt}; 28 | use serde::de::DeserializeOwned; 29 | use tracing::{debug, error_span, trace, trace_span, warn, Instrument}; 30 | 31 | /// Returns a client that connects to a irpc service using an [`iroh::Endpoint`]. 32 | pub fn client( 33 | endpoint: iroh::Endpoint, 34 | addr: impl Into, 35 | alpn: impl AsRef<[u8]>, 36 | ) -> irpc::Client { 37 | let conn = IrohLazyRemoteConnection::new(endpoint, addr.into(), alpn.as_ref().to_vec()); 38 | irpc::Client::boxed(conn) 39 | } 40 | 41 | /// Wrap an existing iroh connection as an irpc remote connection. 42 | /// 43 | /// This will stop working as soon as the underlying iroh connection is closed. 44 | /// If you need to support reconnects, use [`IrohLazyRemoteConnection`] instead. 45 | // TODO: remove this and provide a From instance as soon as iroh is 1.0 and 46 | // we can move irpc-iroh into irpc? 47 | #[derive(Debug, Clone)] 48 | pub struct IrohRemoteConnection(Connection); 49 | 50 | impl IrohRemoteConnection { 51 | pub fn new(connection: Connection) -> Self { 52 | Self(connection) 53 | } 54 | } 55 | 56 | impl irpc::rpc::RemoteConnection for IrohRemoteConnection { 57 | fn clone_boxed(&self) -> Box { 58 | Box::new(self.clone()) 59 | } 60 | 61 | fn open_bi( 62 | &self, 63 | ) -> n0_future::future::Boxed> 64 | { 65 | let conn = self.0.clone(); 66 | Box::pin(async move { 67 | let (send, recv) = conn.open_bi().await?; 68 | Ok((send, recv)) 69 | }) 70 | } 71 | 72 | fn zero_rtt_accepted(&self) -> BoxFuture { 73 | Box::pin(async { true }) 74 | } 75 | } 76 | 77 | #[derive(Debug, Clone)] 78 | pub struct IrohZrttRemoteConnection(OutgoingZeroRttConnection); 79 | 80 | impl IrohZrttRemoteConnection { 81 | pub fn new(connection: OutgoingZeroRttConnection) -> Self { 82 | Self(connection) 83 | } 84 | } 85 | 86 | impl irpc::rpc::RemoteConnection for IrohZrttRemoteConnection { 87 | fn clone_boxed(&self) -> Box { 88 | Box::new(self.clone()) 89 | } 90 | 91 | fn open_bi( 92 | &self, 93 | ) -> 
n0_future::future::Boxed> 94 | { 95 | let conn = self.0.clone(); 96 | Box::pin(async move { 97 | let (send, recv) = conn.open_bi().await?; 98 | Ok((send, recv)) 99 | }) 100 | } 101 | 102 | fn zero_rtt_accepted(&self) -> BoxFuture { 103 | let conn = self.0.clone(); 104 | Box::pin(async move { 105 | match conn.handshake_completed().await { 106 | Err(_) => false, 107 | Ok(ZeroRttStatus::Accepted(_)) => true, 108 | Ok(ZeroRttStatus::Rejected(_)) => false, 109 | } 110 | }) 111 | } 112 | } 113 | 114 | /// A connection to a remote service. 115 | /// 116 | /// Initially this does just have the endpoint and the address. Once a 117 | /// connection is established, it will be stored. 118 | #[derive(Debug, Clone)] 119 | pub struct IrohLazyRemoteConnection(Arc); 120 | 121 | #[derive(Debug)] 122 | struct IrohRemoteConnectionInner { 123 | endpoint: iroh::Endpoint, 124 | addr: iroh::EndpointAddr, 125 | connection: tokio::sync::Mutex>, 126 | alpn: Vec, 127 | } 128 | 129 | impl IrohLazyRemoteConnection { 130 | pub fn new(endpoint: iroh::Endpoint, addr: iroh::EndpointAddr, alpn: Vec) -> Self { 131 | Self(Arc::new(IrohRemoteConnectionInner { 132 | endpoint, 133 | addr, 134 | connection: Default::default(), 135 | alpn, 136 | })) 137 | } 138 | } 139 | 140 | impl RemoteConnection for IrohLazyRemoteConnection { 141 | fn clone_boxed(&self) -> Box { 142 | Box::new(self.clone()) 143 | } 144 | 145 | fn open_bi(&self) -> BoxFuture> { 146 | let this = self.0.clone(); 147 | Box::pin(async move { 148 | let mut guard = this.connection.lock().await; 149 | let pair = match guard.as_mut() { 150 | Some(conn) => { 151 | // try to reuse the connection 152 | match conn.open_bi().await { 153 | Ok(pair) => pair, 154 | Err(_) => { 155 | // try with a new connection, just once 156 | *guard = None; 157 | connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard) 158 | .await? 
159 | } 160 | } 161 | } 162 | None => connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard).await?, 163 | }; 164 | Ok(pair) 165 | }) 166 | } 167 | 168 | fn zero_rtt_accepted(&self) -> BoxFuture { 169 | Box::pin(async { true }) 170 | } 171 | } 172 | 173 | async fn connect_and_open_bi( 174 | endpoint: &iroh::Endpoint, 175 | addr: &iroh::EndpointAddr, 176 | alpn: &[u8], 177 | mut guard: tokio::sync::MutexGuard<'_, Option>, 178 | ) -> Result<(SendStream, RecvStream), RequestError> { 179 | let conn = endpoint 180 | .connect(addr.clone(), alpn) 181 | .await 182 | .map_err(|err| e!(RequestError::Other, err.into()))?; 183 | let (send, recv) = conn.open_bi().await?; 184 | *guard = Some(conn); 185 | Ok((send, recv)) 186 | } 187 | 188 | /// A [`ProtocolHandler`] for an irpc protocol. 189 | /// 190 | /// Can be added to an [`iroh::protocol::Router`] to handle incoming connections for an ALPN string. 191 | pub struct IrohProtocol { 192 | handler: Handler, 193 | request_id: AtomicU64, 194 | } 195 | 196 | impl fmt::Debug for IrohProtocol { 197 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 198 | write!(f, "RpcProtocol") 199 | } 200 | } 201 | 202 | impl IrohProtocol { 203 | pub fn with_sender(local_sender: impl Into>) -> Self 204 | where 205 | R: RemoteService, 206 | { 207 | let handler = R::remote_handler(local_sender.into()); 208 | Self::new(handler) 209 | } 210 | 211 | /// Creates a new [`IrohProtocol`] for the `handler`. 
212 | pub fn new(handler: Handler) -> Self { 213 | Self { 214 | handler, 215 | request_id: Default::default(), 216 | } 217 | } 218 | } 219 | 220 | impl ProtocolHandler for IrohProtocol { 221 | async fn accept(&self, connection: Connection) -> Result<(), AcceptError> { 222 | let handler = self.handler.clone(); 223 | let request_id = self 224 | .request_id 225 | .fetch_add(1, std::sync::atomic::Ordering::AcqRel); 226 | let fut = handle_connection(&connection, handler).map_err(AcceptError::from_err); 227 | let span = trace_span!("rpc", id = request_id); 228 | fut.instrument(span).await 229 | } 230 | } 231 | 232 | /// A [`ProtocolHandler`] for an irpc protocol that supports 0rtt connections. 233 | /// 234 | /// Can be added to an [`iroh::protocol::Router`] to handle incoming connections for an ALPN string. 235 | /// 236 | /// For details about when it is safe to use 0rtt, see https://www.iroh.computer/blog/0rtt-api 237 | pub struct Iroh0RttProtocol { 238 | handler: Handler, 239 | request_id: AtomicU64, 240 | } 241 | 242 | impl fmt::Debug for Iroh0RttProtocol { 243 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 244 | write!(f, "RpcProtocol") 245 | } 246 | } 247 | 248 | impl Iroh0RttProtocol { 249 | pub fn with_sender(local_sender: impl Into>) -> Self 250 | where 251 | R: RemoteService, 252 | { 253 | let handler = R::remote_handler(local_sender.into()); 254 | Self::new(handler) 255 | } 256 | 257 | /// Creates a new [`Iroh0RttProtocol`] for the `handler`. 
258 | pub fn new(handler: Handler) -> Self { 259 | Self { 260 | handler, 261 | request_id: Default::default(), 262 | } 263 | } 264 | } 265 | 266 | impl ProtocolHandler for Iroh0RttProtocol { 267 | async fn on_accepting(&self, accepting: Accepting) -> Result { 268 | let zrtt_conn = accepting.into_0rtt(); 269 | let handler = self.handler.clone(); 270 | let request_id = self 271 | .request_id 272 | .fetch_add(1, std::sync::atomic::Ordering::AcqRel); 273 | handle_connection(&zrtt_conn, handler) 274 | .map_err(AcceptError::from_err) 275 | .instrument(trace_span!("rpc", id = request_id)) 276 | .await?; 277 | let conn = zrtt_conn 278 | .handshake_completed() 279 | .await 280 | .map_err(|err| AcceptError::from(ConnectingError::from(err)))?; 281 | Ok(conn) 282 | } 283 | 284 | async fn accept(&self, _connection: Connection) -> Result<(), AcceptError> { 285 | // Noop, handled in [`Self::on_accepting`] 286 | Ok(()) 287 | } 288 | } 289 | 290 | /// Handles a single iroh connection with the provided `handler`. 291 | pub async fn handle_connection( 292 | connection: &impl IncomingRemoteConnection, 293 | handler: Handler, 294 | ) -> io::Result<()> { 295 | if let Ok(remote) = connection.remote_id() { 296 | tracing::Span::current().record("remote", tracing::field::display(remote.fmt_short())); 297 | } 298 | debug!("connection accepted"); 299 | loop { 300 | let Some((msg, rx, tx)) = read_request_raw(connection).await? else { 301 | return Ok(()); 302 | }; 303 | handler(msg, rx, tx).await?; 304 | } 305 | } 306 | 307 | /// Reads a single request from a connection, and a message with channels. 308 | pub async fn read_request( 309 | connection: &impl IncomingRemoteConnection, 310 | ) -> std::io::Result> { 311 | Ok(read_request_raw::(connection) 312 | .await? 313 | .map(|(msg, rx, tx)| S::with_remote_channels(msg, rx, tx))) 314 | } 315 | 316 | /// Abstracts over [`Connection`] and [`IncomingZeroRttConnection`]. 317 | /// 318 | /// You don't need to implement this trait yourself. 
It is used by [`read_request`] and 319 | /// [`handle_connection`] to work with both fully authenticated connections and with 320 | /// 0-RTT connections. 321 | pub trait IncomingRemoteConnection { 322 | /// Accepts a single bidirectional stream. 323 | fn accept_bi( 324 | &self, 325 | ) -> impl Future> + Send; 326 | /// Close the connection. 327 | fn close(&self, error_code: VarInt, reason: &[u8]); 328 | /// Returns the remote's endpoint id. 329 | /// 330 | /// This may only fail for 0-RTT connections. 331 | fn remote_id(&self) -> Result; 332 | } 333 | 334 | impl IncomingRemoteConnection for IncomingZeroRttConnection { 335 | async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> { 336 | self.accept_bi().await 337 | } 338 | 339 | fn close(&self, error_code: VarInt, reason: &[u8]) { 340 | self.close(error_code, reason) 341 | } 342 | fn remote_id(&self) -> Result { 343 | self.remote_id() 344 | } 345 | } 346 | 347 | impl IncomingRemoteConnection for Connection { 348 | async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> { 349 | self.accept_bi().await 350 | } 351 | 352 | fn close(&self, error_code: VarInt, reason: &[u8]) { 353 | self.close(error_code, reason) 354 | } 355 | fn remote_id(&self) -> Result { 356 | Ok(self.remote_id()) 357 | } 358 | } 359 | 360 | /// Reads a single request from the connection. 361 | /// 362 | /// This accepts a bi-directional stream from the connection and reads and parses the request. 363 | /// 364 | /// Returns the parsed request and the stream pair if reading and parsing the request succeeded. 365 | /// Returns None if the remote closed the connection with error code `0`. 366 | /// Returns an error for all other failure cases. 
367 | pub async fn read_request_raw( 368 | connection: &impl IncomingRemoteConnection, 369 | ) -> std::io::Result> { 370 | let (send, mut recv) = match connection.accept_bi().await { 371 | Ok((s, r)) => (s, r), 372 | Err(ConnectionError::ApplicationClosed(cause)) if cause.error_code.into_inner() == 0 => { 373 | trace!("remote side closed connection {cause:?}"); 374 | return Ok(None); 375 | } 376 | Err(cause) => { 377 | warn!("failed to accept bi stream {cause:?}"); 378 | return Err(cause.into()); 379 | } 380 | }; 381 | let size = recv 382 | .read_varint_u64() 383 | .await? 384 | .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?; 385 | if size > MAX_MESSAGE_SIZE { 386 | connection.close( 387 | ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(), 388 | b"request exceeded max message size", 389 | ); 390 | return Err(e!(oneshot::RecvError::MaxMessageSizeExceeded).into()); 391 | } 392 | let mut buf = vec![0; size as usize]; 393 | recv.read_exact(&mut buf) 394 | .await 395 | .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?; 396 | let msg: R = 397 | postcard::from_bytes(&buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; 398 | let rx = recv; 399 | let tx = send; 400 | Ok(Some((msg, rx, tx))) 401 | } 402 | 403 | /// Utility function to listen for incoming connections and handle them with the provided handler 404 | pub async fn listen(endpoint: iroh::Endpoint, handler: Handler) { 405 | let mut request_id = 0u64; 406 | let mut tasks = n0_future::task::JoinSet::new(); 407 | loop { 408 | let incoming = tokio::select! 
{ 409 | Some(res) = tasks.join_next(), if !tasks.is_empty() => { 410 | res.expect("irpc connection task panicked"); 411 | continue; 412 | } 413 | incoming = endpoint.accept() => { 414 | match incoming { 415 | None => break, 416 | Some(incoming) => incoming 417 | } 418 | } 419 | }; 420 | let handler = handler.clone(); 421 | let fut = async move { 422 | match incoming.await { 423 | Ok(connection) => match handle_connection(&connection, handler).await { 424 | Err(err) => warn!("connection closed with error: {err:?}"), 425 | Ok(()) => debug!("connection closed"), 426 | }, 427 | Err(cause) => { 428 | warn!("failed to accept connection: {cause:?}"); 429 | } 430 | }; 431 | }; 432 | let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty); 433 | tasks.spawn(fut.instrument(span)); 434 | request_id += 1; 435 | } 436 | } 437 | -------------------------------------------------------------------------------- /src/util.rs: -------------------------------------------------------------------------------- 1 | //! Utilities 2 | //! 3 | //! This module contains utilities to read and write varints, as well as 4 | //! functions to set up quinn endpoints for local rpc and testing. 5 | #[cfg(feature = "quinn_endpoint_setup")] 6 | #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "quinn_endpoint_setup")))] 7 | mod quinn_setup_utils { 8 | use std::{sync::Arc, time::Duration}; 9 | 10 | use n0_error::{Result, StdResultExt}; 11 | use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, ServerConfig}; 12 | 13 | /// Create a quinn client config and trusts given certificates. 14 | /// 15 | /// ## Args 16 | /// 17 | /// - server_certs: a list of trusted certificates in DER format. 
18 | pub fn configure_client(server_certs: &[&[u8]]) -> Result { 19 | let mut certs = rustls::RootCertStore::empty(); 20 | for cert in server_certs { 21 | let cert = rustls::pki_types::CertificateDer::from(cert.to_vec()); 22 | certs.add(cert).std_context("Error configuring certs")?; 23 | } 24 | 25 | let provider = rustls::crypto::ring::default_provider(); 26 | let crypto_client_config = rustls::ClientConfig::builder_with_provider(Arc::new(provider)) 27 | .with_protocol_versions(&[&rustls::version::TLS13]) 28 | .expect("valid versions") 29 | .with_root_certificates(certs) 30 | .with_no_client_auth(); 31 | let quic_client_config = 32 | quinn::crypto::rustls::QuicClientConfig::try_from(crypto_client_config) 33 | .std_context("Error creating QUIC client config")?; 34 | 35 | let mut transport_config = quinn::TransportConfig::default(); 36 | transport_config.keep_alive_interval(Some(Duration::from_secs(1))); 37 | let mut client_config = ClientConfig::new(Arc::new(quic_client_config)); 38 | client_config.transport_config(Arc::new(transport_config)); 39 | Ok(client_config) 40 | } 41 | 42 | /// Create a quinn server config with a self-signed certificate 43 | /// 44 | /// Returns the server config and the certificate in DER format 45 | pub fn configure_server() -> Result<(ServerConfig, Vec)> { 46 | let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]) 47 | .std_context("Error generating self-signed cert")?; 48 | let cert_der = cert.cert.der(); 49 | let priv_key = 50 | rustls::pki_types::PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der()); 51 | let cert_chain = vec![cert_der.clone()]; 52 | 53 | let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key.into()) 54 | .std_context("Error creating server config")?; 55 | Arc::get_mut(&mut server_config.transport) 56 | .unwrap() 57 | .max_concurrent_uni_streams(0_u8.into()); 58 | 59 | Ok((server_config, cert_der.to_vec())) 60 | } 61 | 62 | /// Create a quinn client config and trust all 
certificates. 63 | pub fn configure_client_insecure() -> Result { 64 | let provider = rustls::crypto::ring::default_provider(); 65 | let crypto = rustls::ClientConfig::builder_with_provider(Arc::new(provider)) 66 | .with_protocol_versions(rustls::DEFAULT_VERSIONS) 67 | .expect("valid versions") 68 | .dangerous() 69 | .with_custom_certificate_verifier(Arc::new(SkipServerVerification)) 70 | .with_no_client_auth(); 71 | let client_cfg = 72 | QuicClientConfig::try_from(crypto).std_context("Error creating QUIC client config")?; 73 | let client_cfg = ClientConfig::new(Arc::new(client_cfg)); 74 | Ok(client_cfg) 75 | } 76 | 77 | #[cfg(not(target_arch = "wasm32"))] 78 | mod non_wasm { 79 | use std::net::SocketAddr; 80 | 81 | use quinn::Endpoint; 82 | 83 | use super::*; 84 | 85 | /// Constructs a QUIC endpoint configured for use a client only. 86 | /// 87 | /// ## Args 88 | /// 89 | /// - server_certs: list of trusted certificates. 90 | pub fn make_client_endpoint( 91 | bind_addr: SocketAddr, 92 | server_certs: &[&[u8]], 93 | ) -> Result { 94 | let client_cfg = configure_client(server_certs)?; 95 | let mut endpoint = Endpoint::client(bind_addr)?; 96 | endpoint.set_default_client_config(client_cfg); 97 | Ok(endpoint) 98 | } 99 | 100 | /// Constructs a QUIC endpoint configured for use a client only that trusts all certificates. 101 | /// 102 | /// This is useful for testing and local connections, but should be used with care. 
103 | pub fn make_insecure_client_endpoint(bind_addr: SocketAddr) -> Result { 104 | let client_cfg = configure_client_insecure()?; 105 | let mut endpoint = Endpoint::client(bind_addr)?; 106 | endpoint.set_default_client_config(client_cfg); 107 | Ok(endpoint) 108 | } 109 | 110 | /// Constructs a QUIC server endpoint with a self-signed certificate 111 | /// 112 | /// Returns the server endpoint and the certificate in DER format 113 | pub fn make_server_endpoint(bind_addr: SocketAddr) -> Result<(Endpoint, Vec)> { 114 | let (server_config, server_cert) = configure_server()?; 115 | let endpoint = Endpoint::server(server_config, bind_addr)?; 116 | Ok((endpoint, server_cert)) 117 | } 118 | } 119 | 120 | #[cfg(not(target_arch = "wasm32"))] 121 | pub use non_wasm::{make_client_endpoint, make_insecure_client_endpoint, make_server_endpoint}; 122 | 123 | #[derive(Debug)] 124 | struct SkipServerVerification; 125 | 126 | impl rustls::client::danger::ServerCertVerifier for SkipServerVerification { 127 | fn verify_server_cert( 128 | &self, 129 | _end_entity: &rustls::pki_types::CertificateDer<'_>, 130 | _intermediates: &[rustls::pki_types::CertificateDer<'_>], 131 | _server_name: &rustls::pki_types::ServerName<'_>, 132 | _ocsp_response: &[u8], 133 | _now: rustls::pki_types::UnixTime, 134 | ) -> Result { 135 | Ok(rustls::client::danger::ServerCertVerified::assertion()) 136 | } 137 | 138 | fn verify_tls12_signature( 139 | &self, 140 | _message: &[u8], 141 | _cert: &rustls::pki_types::CertificateDer<'_>, 142 | _dss: &rustls::DigitallySignedStruct, 143 | ) -> Result { 144 | Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) 145 | } 146 | 147 | fn verify_tls13_signature( 148 | &self, 149 | _message: &[u8], 150 | _cert: &rustls::pki_types::CertificateDer<'_>, 151 | _dss: &rustls::DigitallySignedStruct, 152 | ) -> Result { 153 | Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) 154 | } 155 | 156 | fn supported_verify_schemes(&self) -> Vec { 157 | use 
rustls::SignatureScheme::*; 158 | // list them all, we don't care. 159 | vec![ 160 | RSA_PKCS1_SHA1, 161 | ECDSA_SHA1_Legacy, 162 | RSA_PKCS1_SHA256, 163 | ECDSA_NISTP256_SHA256, 164 | RSA_PKCS1_SHA384, 165 | ECDSA_NISTP384_SHA384, 166 | RSA_PKCS1_SHA512, 167 | ECDSA_NISTP521_SHA512, 168 | RSA_PSS_SHA256, 169 | RSA_PSS_SHA384, 170 | RSA_PSS_SHA512, 171 | ED25519, 172 | ED448, 173 | ] 174 | } 175 | } 176 | } 177 | #[cfg(feature = "quinn_endpoint_setup")] 178 | #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "quinn_endpoint_setup")))] 179 | pub use quinn_setup_utils::*; 180 | 181 | #[cfg(any(feature = "rpc", feature = "varint-util"))] 182 | #[cfg_attr( 183 | quicrpc_docsrs, 184 | doc(cfg(any(feature = "rpc", feature = "varint-util"))) 185 | )] 186 | mod varint_util { 187 | use std::{ 188 | future::Future, 189 | io::{self, Error}, 190 | }; 191 | 192 | use serde::{de::DeserializeOwned, Serialize}; 193 | use smallvec::SmallVec; 194 | use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; 195 | 196 | /// Reads a u64 varint from an AsyncRead source, using the Postcard/LEB128 format. 197 | /// 198 | /// In Postcard's varint format (LEB128): 199 | /// - Each byte uses 7 bits for the value 200 | /// - The MSB (most significant bit) of each byte indicates if there are more bytes (1) or not (0) 201 | /// - Values are stored in little-endian order (least significant group first) 202 | /// 203 | /// Returns the decoded u64 value. 
204 | pub async fn read_varint_u64(reader: &mut R) -> io::Result> 205 | where 206 | R: AsyncRead + Unpin, 207 | { 208 | let mut result: u64 = 0; 209 | let mut shift: u32 = 0; 210 | 211 | loop { 212 | // We can only shift up to 63 bits (for a u64) 213 | if shift >= 64 { 214 | return Err(io::Error::new( 215 | io::ErrorKind::InvalidData, 216 | "Varint is too large for u64", 217 | )); 218 | } 219 | 220 | // Read a single byte 221 | let res = reader.read_u8().await; 222 | if shift == 0 { 223 | if let Err(cause) = res { 224 | if cause.kind() == io::ErrorKind::UnexpectedEof { 225 | return Ok(None); 226 | } else { 227 | return Err(cause); 228 | } 229 | } 230 | } 231 | 232 | let byte = res?; 233 | 234 | // Extract the 7 value bits (bits 0-6, excluding the MSB which is the continuation bit) 235 | let value = (byte & 0x7F) as u64; 236 | 237 | // Add the bits to our result at the current shift position 238 | result |= value << shift; 239 | 240 | // If the high bit is not set (0), this is the last byte 241 | if byte & 0x80 == 0 { 242 | break; 243 | } 244 | 245 | // Move to the next 7 bits 246 | shift += 7; 247 | } 248 | 249 | Ok(Some(result)) 250 | } 251 | 252 | /// Writes a u64 varint to any object that implements the `std::io::Write` trait. 253 | /// 254 | /// This encodes the value using LEB128 encoding. 
255 | /// 256 | /// # Arguments 257 | /// * `writer` - Any object implementing `std::io::Write` 258 | /// * `value` - The u64 value to encode as a varint 259 | /// 260 | /// # Returns 261 | /// The number of bytes written or an IO error 262 | pub fn write_varint_u64_sync( 263 | writer: &mut W, 264 | value: u64, 265 | ) -> std::io::Result { 266 | // Handle zero as a special case 267 | if value == 0 { 268 | writer.write_all(&[0])?; 269 | return Ok(1); 270 | } 271 | 272 | let mut bytes_written = 0; 273 | let mut remaining = value; 274 | 275 | while remaining > 0 { 276 | // Extract the 7 least significant bits 277 | let mut byte = (remaining & 0x7F) as u8; 278 | remaining >>= 7; 279 | 280 | // Set the continuation bit if there's more data 281 | if remaining > 0 { 282 | byte |= 0x80; 283 | } 284 | 285 | writer.write_all(&[byte])?; 286 | bytes_written += 1; 287 | } 288 | 289 | Ok(bytes_written) 290 | } 291 | 292 | pub fn write_length_prefixed( 293 | mut write: impl std::io::Write, 294 | value: T, 295 | ) -> io::Result<()> { 296 | let size = postcard::experimental::serialized_size(&value) 297 | .map_err(|e| Error::new(io::ErrorKind::InvalidData, e))? as u64; 298 | write_varint_u64_sync(&mut write, size)?; 299 | postcard::to_io(&value, &mut write) 300 | .map_err(|e| Error::new(io::ErrorKind::InvalidData, e))?; 301 | Ok(()) 302 | } 303 | 304 | /// Provides a fn to read a varint from an AsyncRead source. 305 | pub trait AsyncReadVarintExt: AsyncRead + Unpin { 306 | /// Reads a u64 varint from an AsyncRead source, using the Postcard/LEB128 format. 307 | /// 308 | /// If the stream is at the end, this returns `Ok(None)`. 
309 | fn read_varint_u64(&mut self) -> impl Future>>; 310 | 311 | fn read_length_prefixed( 312 | &mut self, 313 | max_size: usize, 314 | ) -> impl Future>; 315 | } 316 | 317 | impl AsyncReadVarintExt for T { 318 | fn read_varint_u64(&mut self) -> impl Future>> { 319 | read_varint_u64(self) 320 | } 321 | 322 | async fn read_length_prefixed( 323 | &mut self, 324 | max_size: usize, 325 | ) -> io::Result { 326 | let size = match self.read_varint_u64().await? { 327 | Some(size) => size, 328 | None => return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "EOF reached")), 329 | }; 330 | 331 | if size > max_size as u64 { 332 | return Err(io::Error::new( 333 | io::ErrorKind::InvalidData, 334 | "Length-prefixed value too large", 335 | )); 336 | } 337 | 338 | let mut buf = vec![0; size as usize]; 339 | self.read_exact(&mut buf).await?; 340 | postcard::from_bytes(&buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) 341 | } 342 | } 343 | 344 | /// Provides a fn to write a varint to an [`io::Write`] target, as well as a 345 | /// helper to write a length-prefixed value. 346 | pub trait WriteVarintExt: std::io::Write { 347 | /// Write a varint 348 | #[allow(dead_code)] 349 | fn write_varint_u64(&mut self, value: u64) -> io::Result; 350 | /// Write a value with a varint encoded length prefix. 351 | fn write_length_prefixed(&mut self, value: T) -> io::Result<()>; 352 | } 353 | 354 | impl WriteVarintExt for T { 355 | fn write_varint_u64(&mut self, value: u64) -> io::Result { 356 | write_varint_u64_sync(self, value) 357 | } 358 | 359 | fn write_length_prefixed(&mut self, value: V) -> io::Result<()> { 360 | write_length_prefixed(self, value) 361 | } 362 | } 363 | 364 | /// Provides a fn to write a varint to an [`io::Write`] target, as well as a 365 | /// helper to write a length-prefixed value. 
366 | pub trait AsyncWriteVarintExt: AsyncWrite + Unpin { 367 | /// Write a varint 368 | fn write_varint_u64(&mut self, value: u64) -> impl Future>; 369 | /// Write a value with a varint encoded length prefix. 370 | fn write_length_prefixed( 371 | &mut self, 372 | value: T, 373 | ) -> impl Future>; 374 | } 375 | 376 | impl AsyncWriteVarintExt for T { 377 | async fn write_varint_u64(&mut self, value: u64) -> io::Result { 378 | let mut buf: SmallVec<[u8; 10]> = Default::default(); 379 | write_varint_u64_sync(&mut buf, value).unwrap(); 380 | self.write_all(&buf[..]).await?; 381 | Ok(buf.len()) 382 | } 383 | 384 | async fn write_length_prefixed(&mut self, value: V) -> io::Result { 385 | let mut buf = Vec::new(); 386 | write_length_prefixed(&mut buf, value)?; 387 | let size = buf.len(); 388 | self.write_all(&buf).await?; 389 | Ok(size) 390 | } 391 | } 392 | } 393 | 394 | #[cfg(any(feature = "rpc", feature = "varint-util"))] 395 | #[cfg_attr( 396 | quicrpc_docsrs, 397 | doc(cfg(any(feature = "rpc", feature = "varint-util"))) 398 | )] 399 | pub use varint_util::{AsyncReadVarintExt, AsyncWriteVarintExt, WriteVarintExt}; 400 | 401 | mod fuse_wrapper { 402 | use std::{ 403 | future::Future, 404 | pin::Pin, 405 | result::Result, 406 | task::{Context, Poll}, 407 | }; 408 | 409 | pub struct FusedOneshotReceiver(pub tokio::sync::oneshot::Receiver); 410 | 411 | impl Future for FusedOneshotReceiver { 412 | type Output = Result; 413 | 414 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 415 | if self.0.is_terminated() { 416 | // don't panic when polling a terminated receiver 417 | Poll::Pending 418 | } else { 419 | Future::poll(Pin::new(&mut self.0), cx) 420 | } 421 | } 422 | } 423 | } 424 | pub(crate) use fuse_wrapper::FusedOneshotReceiver; 425 | 426 | #[cfg(feature = "rpc")] 427 | mod now_or_never { 428 | use std::{ 429 | future::Future, 430 | pin::Pin, 431 | task::{Context, Poll, RawWaker, RawWakerVTable, Waker}, 432 | }; 433 | 434 | // Simple pin_mut! 
macro implementation 435 | macro_rules! pin_mut { 436 | ($($x:ident),* $(,)?) => { 437 | $( 438 | let mut $x = $x; 439 | #[allow(unused_mut)] 440 | let mut $x = unsafe { Pin::new_unchecked(&mut $x) }; 441 | )* 442 | } 443 | } 444 | 445 | // Minimal implementation of a no-op waker 446 | fn noop_waker() -> Waker { 447 | fn noop(_: *const ()) {} 448 | fn clone(_: *const ()) -> RawWaker { 449 | let vtable = &RawWakerVTable::new(clone, noop, noop, noop); 450 | RawWaker::new(std::ptr::null(), vtable) 451 | } 452 | 453 | unsafe { Waker::from_raw(clone(std::ptr::null())) } 454 | } 455 | 456 | /// Attempts to complete a future immediately, returning None if it would block 457 | pub(crate) fn now_or_never(future: F) -> Option { 458 | let waker = noop_waker(); 459 | let mut cx = Context::from_waker(&waker); 460 | 461 | pin_mut!(future); 462 | 463 | match future.poll(&mut cx) { 464 | Poll::Ready(x) => Some(x), 465 | Poll::Pending => None, 466 | } 467 | } 468 | } 469 | #[cfg(feature = "rpc")] 470 | pub(crate) use now_or_never::now_or_never; 471 | -------------------------------------------------------------------------------- /irpc-derive/src/lib.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashSet; 2 | 3 | use proc_macro::TokenStream; 4 | use proc_macro2::{Span, TokenStream as TokenStream2}; 5 | use quote::{quote, ToTokens}; 6 | use syn::{ 7 | parse::{Parse, ParseStream}, 8 | parse_macro_input, 9 | punctuated::Punctuated, 10 | spanned::Spanned, 11 | token::Comma, 12 | Attribute, Data, DeriveInput, Error, Fields, Ident, LitStr, Token, Type, Visibility, 13 | }; 14 | 15 | /// Attribute on protocol enums and variants 16 | const RPC_ATTR_NAME: &str = "rpc"; 17 | /// Attribute on variants to wrap in generated struct 18 | const WRAP_ATTR_NAME: &str = "wrap"; 19 | /// The tx type name 20 | const TX_ATTR: &str = "tx"; 21 | /// The rx type name 22 | const RX_ATTR: &str = "rx"; 23 | /// Fully qualified path to the default 
rx type 24 | const DEFAULT_RX_TYPE: &str = "::irpc::channel::none::NoReceiver"; 25 | /// Fully qualified path to the default tx type 26 | const DEFAULT_TX_TYPE: &str = "::irpc::channel::none::NoSender"; 27 | 28 | // See `irpc::rpc_requests` for docs. 29 | #[proc_macro_attribute] 30 | pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { 31 | let mut input = parse_macro_input!(item as DeriveInput); 32 | let args = parse_macro_input!(attr as MacroArgs); 33 | 34 | let enum_name = &input.ident; 35 | let vis = &input.vis; 36 | 37 | let data_enum = match &mut input.data { 38 | Data::Enum(data_enum) => data_enum, 39 | _ => { 40 | return error_tokens( 41 | input.span(), 42 | "The rpc_requests macro can only be applied to enums", 43 | ) 44 | } 45 | }; 46 | 47 | let cfg_feature_rpc = match args.rpc_feature.as_ref() { 48 | None => quote!(), 49 | Some(feature) => quote!(#[cfg(feature = #feature)]), 50 | }; 51 | 52 | // Collect trait implementations 53 | let mut channel_impls = TokenStream2::new(); 54 | // Types to check for uniqueness 55 | let mut types = HashSet::new(); 56 | // All variant names and types 57 | let mut all_variants = Vec::new(); 58 | // Variants with rpc attributes (for From implementations) 59 | let mut variants_with_attr = Vec::new(); 60 | // Wrapper types (via wrap attribute) 61 | let mut wrapper_types = TokenStream2::new(); 62 | 63 | for variant in &mut data_enum.variants { 64 | let rpc_attr = match VariantRpcArgs::from_attrs(&mut variant.attrs) { 65 | Ok(args) => args, 66 | Err(err) => return err.into_compile_error().into(), 67 | }; 68 | 69 | let request_type = match rpc_attr.wrap { 70 | None => match &mut variant.fields { 71 | Fields::Unnamed(ref mut fields) if fields.unnamed.len() == 1 => { 72 | fields.unnamed[0].ty.clone() 73 | } 74 | _ => return error_tokens( 75 | variant.span(), 76 | "Each variant must either have exactly one unnamed field, or use the `wrap` argument in the `rpc` attribute.", 77 | ), 78 | }, 79 | Some(WrapArgs { 
ident, derive, vis }) => { 80 | let vis = vis.as_ref().unwrap_or(&input.vis).clone(); 81 | let ty = type_from_ident(&ident); 82 | let struc = struct_from_variant_fields(ident, variant.fields.clone(), variant.attrs.clone(), vis); 83 | wrapper_types.extend(quote! { 84 | #[derive(::std::fmt::Debug, ::serde::Serialize, ::serde::Deserialize, #(#derive),* )] 85 | #struc 86 | }); 87 | variant.fields = single_unnamed_field(ty.clone()); 88 | ty 89 | } 90 | }; 91 | 92 | all_variants.push((variant.ident.clone(), request_type.clone())); 93 | 94 | if !types.insert(request_type.to_token_stream().to_string()) { 95 | return error_tokens( 96 | variant.span(), 97 | "Each variant must have a unique request type", 98 | ); 99 | } 100 | 101 | if let Some(args) = rpc_attr.rpc { 102 | variants_with_attr.push((variant.ident.clone(), request_type.clone())); 103 | channel_impls.extend(generate_channels_impl(args, enum_name, &request_type)) 104 | } 105 | } 106 | 107 | // Generate From implementations for the original enum (only for variants with rpc attributes) 108 | let protocol_enum_from_impls = 109 | generate_protocol_enum_from_impls(enum_name, &variants_with_attr); 110 | 111 | // Generate type aliases if requested 112 | let type_aliases = if let Some(suffix) = args.alias_suffix { 113 | // Use all variants for type aliases, not just those with rpc attributes 114 | generate_type_aliases(&all_variants, enum_name, &suffix) 115 | } else { 116 | quote! {} 117 | }; 118 | 119 | // Generate the extended message enum if requested 120 | let extended_enum_code = if let Some(message_enum_name) = args.message_enum_name.as_ref() { 121 | let message_variants = all_variants 122 | .iter() 123 | .map(|(variant_name, inner_type)| { 124 | quote! 
{ 125 | #variant_name(::irpc::WithChannels<#inner_type, #enum_name>) 126 | } 127 | }) 128 | .collect::>(); 129 | 130 | // Extract variant names for the parent_span implementation 131 | let variant_names: Vec<&Ident> = all_variants.iter().map(|(name, _)| name).collect(); 132 | 133 | // Create the message enum definition 134 | let doc = format!("Message enum for [`{enum_name}`]"); 135 | let message_enum = quote! { 136 | #[doc = #doc] 137 | #[allow(missing_docs)] 138 | #[derive(::std::fmt::Debug)] 139 | #vis enum #message_enum_name { 140 | #(#message_variants),* 141 | } 142 | }; 143 | 144 | // Generate parent_span method 145 | let parent_span_impl = if !args.no_spans { 146 | generate_parent_span_impl(message_enum_name, &variant_names) 147 | } else { 148 | quote! {} 149 | }; 150 | 151 | // Generate From implementations for the message enum (only for variants with rpc attributes) 152 | let message_from_impls = 153 | generate_message_enum_from_impls(message_enum_name, &variants_with_attr, enum_name); 154 | 155 | let service_impl = quote! { 156 | impl ::irpc::Service for #enum_name { 157 | type Message = #message_enum_name; 158 | } 159 | }; 160 | 161 | let remote_service_impl = if !args.no_rpc { 162 | let block = 163 | generate_remote_service_impl(message_enum_name, enum_name, &variants_with_attr); 164 | quote! { 165 | #cfg_feature_rpc 166 | #block 167 | } 168 | } else { 169 | quote! {} 170 | }; 171 | 172 | quote! { 173 | #message_enum 174 | #service_impl 175 | #remote_service_impl 176 | #parent_span_impl 177 | #message_from_impls 178 | } 179 | } else { 180 | quote! {} 181 | }; 182 | 183 | // Combine everything 184 | let output = quote! 
{ 185 | #input 186 | 187 | // Wrapper types 188 | #wrapper_types 189 | 190 | // Channel implementations 191 | #channel_impls 192 | 193 | // From implementations for the original enum 194 | #protocol_enum_from_impls 195 | 196 | // Type aliases for WithChannels 197 | #type_aliases 198 | 199 | // Extended enum and its implementations 200 | #extended_enum_code 201 | }; 202 | 203 | output.into() 204 | } 205 | 206 | /// Generate parent span method for an enum 207 | fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> TokenStream2 { 208 | quote! { 209 | impl #enum_name { 210 | /// Get the parent span of the message 211 | pub fn parent_span(&self) -> ::tracing::Span { 212 | let span = match self { 213 | #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),* 214 | }; 215 | span.cloned().unwrap_or_else(|| ::tracing::Span::current()) 216 | } 217 | } 218 | } 219 | } 220 | 221 | fn generate_channels_impl( 222 | args: RpcArgs, 223 | service_name: &Ident, 224 | request_type: &Type, 225 | ) -> TokenStream2 { 226 | let rx = args.rx.unwrap_or_else(|| { 227 | // We can safely unwrap here because this is a known valid type 228 | syn::parse_str::(DEFAULT_RX_TYPE).expect("Failed to parse default rx type") 229 | }); 230 | let tx = args.tx.unwrap_or_else(|| { 231 | // We can safely unwrap here because this is a known valid type 232 | syn::parse_str::(DEFAULT_TX_TYPE).expect("Failed to parse default tx type") 233 | }); 234 | 235 | quote! { 236 | impl ::irpc::Channels<#service_name> for #request_type { 237 | type Tx = #tx; 238 | type Rx = #rx; 239 | } 240 | } 241 | } 242 | 243 | /// Generates `From` impls for protocol enum variants with an rpc attribute. 244 | fn generate_protocol_enum_from_impls( 245 | enum_name: &Ident, 246 | variants_with_attr: &[(Ident, Type)], 247 | ) -> TokenStream2 { 248 | variants_with_attr 249 | .iter() 250 | .map(|(variant_name, inner_type)| { 251 | quote! 
{ 252 | impl From<#inner_type> for #enum_name { 253 | fn from(value: #inner_type) -> Self { 254 | #enum_name::#variant_name(value) 255 | } 256 | } 257 | } 258 | }) 259 | .collect() 260 | } 261 | 262 | /// Generate `From>` impls for message enum variants. 263 | fn generate_message_enum_from_impls( 264 | message_enum_name: &Ident, 265 | variants_with_attr: &[(Ident, Type)], 266 | service_name: &Ident, 267 | ) -> TokenStream2 { 268 | variants_with_attr 269 | .iter() 270 | .map(|(variant_name, inner_type)| { 271 | quote! { 272 | impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name { 273 | fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self { 274 | #message_enum_name::#variant_name(value) 275 | } 276 | } 277 | } 278 | }) 279 | .collect() 280 | } 281 | 282 | /// Generate `RemoteService` impl for message enums. 283 | fn generate_remote_service_impl( 284 | message_enum_name: &Ident, 285 | proto_enum_name: &Ident, 286 | variants_with_attr: &[(Ident, Type)], 287 | ) -> TokenStream2 { 288 | let variants = variants_with_attr 289 | .iter() 290 | .map(|(variant_name, _inner_type)| { 291 | quote! { 292 | #proto_enum_name::#variant_name(msg) => { 293 | #message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx))) 294 | } 295 | } 296 | }); 297 | 298 | quote! 
{ 299 | impl ::irpc::rpc::RemoteService for #proto_enum_name { 300 | fn with_remote_channels( 301 | self, 302 | rx: ::irpc::rpc::quinn::RecvStream, 303 | tx: ::irpc::rpc::quinn::SendStream 304 | ) -> Self::Message { 305 | match self { 306 | #(#variants),* 307 | } 308 | } 309 | } 310 | } 311 | } 312 | 313 | /// Generate type aliases for `WithChannels` 314 | fn generate_type_aliases( 315 | variants: &[(Ident, Type)], 316 | service_name: &Ident, 317 | suffix: &str, 318 | ) -> TokenStream2 { 319 | variants 320 | .iter() 321 | .map(|(variant_name, inner_type)| { 322 | // Create a type name using the variant name + suffix 323 | // For example: Sum + "Msg" = SumMsg 324 | let type_name = format!("{variant_name}{suffix}"); 325 | let type_ident = Ident::new(&type_name, variant_name.span()); 326 | quote! { 327 | /// Type alias for WithChannels<#inner_type, #service_name> 328 | pub type #type_ident = ::irpc::WithChannels<#inner_type, #service_name>; 329 | } 330 | }) 331 | .collect() 332 | } 333 | 334 | // Parse arguments for the macro 335 | #[derive(Default)] 336 | struct MacroArgs { 337 | message_enum_name: Option, 338 | alias_suffix: Option, 339 | rpc_feature: Option, 340 | no_rpc: bool, 341 | no_spans: bool, 342 | } 343 | 344 | impl Parse for MacroArgs { 345 | fn parse(input: ParseStream) -> syn::Result { 346 | let mut this = Self::default(); 347 | loop { 348 | let arg: Ident = input.parse()?; 349 | match arg.to_string().as_str() { 350 | "message" => { 351 | input.parse::()?; 352 | let value: Ident = input.parse()?; 353 | this.message_enum_name = Some(value); 354 | } 355 | "alias" => { 356 | input.parse::()?; 357 | let value: LitStr = input.parse()?; 358 | this.alias_suffix = Some(value.value()); 359 | } 360 | "rpc_feature" => { 361 | input.parse::()?; 362 | if this.no_rpc { 363 | return syn_err(arg.span(), "rpc_feature is incompatible with no_rpc"); 364 | } 365 | let value: LitStr = input.parse()?; 366 | this.rpc_feature = Some(value.value()); 367 | } 368 | "no_rpc" => { 
369 | if this.rpc_feature.is_some() { 370 | return syn_err(arg.span(), "rpc_feature is incompatible with no_rpc"); 371 | } 372 | this.no_rpc = true; 373 | } 374 | "no_spans" => { 375 | this.no_spans = true; 376 | } 377 | _ => { 378 | return syn_err(arg.span(), format!("Unknown parameter: {arg}")); 379 | } 380 | } 381 | 382 | if input.peek(Token![,]) { 383 | input.parse::()?; 384 | } else { 385 | break; 386 | } 387 | } 388 | 389 | Ok(this) 390 | } 391 | } 392 | 393 | #[derive(Default)] 394 | struct VariantRpcArgs { 395 | wrap: Option, 396 | rpc: Option, 397 | } 398 | 399 | impl VariantRpcArgs { 400 | fn from_attrs(attrs: &mut Vec) -> syn::Result { 401 | let mut this = Self::default(); 402 | let mut remaining_attrs = Vec::new(); 403 | for attr in attrs.drain(..) { 404 | let ident = attr.path().get_ident().map(|ident| ident.to_string()); 405 | match ident.as_deref() { 406 | Some(RPC_ATTR_NAME) => { 407 | if this.rpc.is_some() { 408 | syn_err(attr.span(), "Each variant can have only one rpc attribute")?; 409 | } 410 | this.rpc = Some(attr.parse_args()?); 411 | } 412 | Some(WRAP_ATTR_NAME) => { 413 | if this.wrap.is_some() { 414 | syn_err(attr.span(), "Each variant can have only one wrap attribute")?; 415 | } 416 | this.wrap = Some(attr.parse_args()?); 417 | } 418 | _ => remaining_attrs.push(attr), 419 | } 420 | } 421 | *attrs = remaining_attrs; 422 | Ok(this) 423 | } 424 | } 425 | 426 | #[derive(Default)] 427 | struct RpcArgs { 428 | rx: Option, 429 | tx: Option, 430 | } 431 | 432 | /// Parse the rpc args as a comma separated list of name=type pairs 433 | impl Parse for RpcArgs { 434 | fn parse(input: ParseStream) -> syn::Result { 435 | let mut this = Self::default(); 436 | while !input.is_empty() { 437 | let arg: Ident = input.parse()?; 438 | let _: Token![=] = input.parse()?; 439 | let value: Type = input.parse()?; 440 | if arg == RX_ATTR { 441 | this.rx = Some(value); 442 | } else if arg == TX_ATTR { 443 | this.tx = Some(value); 444 | } else { 445 | 
syn_err(arg.span(), "Unexpected argument in rpc attribute")?; 446 | } 447 | if !input.peek(Token![,]) { 448 | break; 449 | } else { 450 | let _: Token![,] = input.parse()?; 451 | } 452 | } 453 | 454 | Ok(this) 455 | } 456 | } 457 | 458 | struct WrapArgs { 459 | vis: Option, 460 | ident: Ident, 461 | derive: Vec, 462 | } 463 | 464 | impl Parse for WrapArgs { 465 | fn parse(input: ParseStream) -> syn::Result { 466 | let vis = match input.parse::()? { 467 | Visibility::Inherited => None, 468 | vis => Some(vis), 469 | }; 470 | let ident: Ident = input.parse()?; 471 | let mut this = Self { 472 | ident, 473 | derive: Default::default(), 474 | vis, 475 | }; 476 | while input.peek(Token![,]) { 477 | let _: Token![,] = input.parse()?; 478 | let arg: Ident = input.parse()?; 479 | match arg.to_string().as_str() { 480 | "derive" => { 481 | let content; 482 | syn::parenthesized!(content in input); 483 | let types: Punctuated = Punctuated::parse_terminated(&content)?; 484 | this.derive = types.into_iter().collect(); 485 | } 486 | _ => syn_err(arg.span(), "Unexpected argument in wrap argument")?, 487 | } 488 | } 489 | if !input.is_empty() { 490 | syn_err(input.span(), "Unexpected tokens in wrap argument")?; 491 | } 492 | Ok(this) 493 | } 494 | } 495 | 496 | fn type_from_ident(ident: &Ident) -> Type { 497 | Type::Path(syn::TypePath { 498 | qself: None, 499 | path: syn::Path { 500 | leading_colon: None, 501 | segments: Punctuated::from_iter([syn::PathSegment::from(ident.clone())]), 502 | }, 503 | }) 504 | } 505 | 506 | fn struct_from_variant_fields( 507 | ident: Ident, 508 | mut fields: Fields, 509 | attrs: Vec, 510 | vis: Visibility, 511 | ) -> syn::ItemStruct { 512 | set_fields_vis(&mut fields, &vis); 513 | let span = ident.span(); 514 | syn::ItemStruct { 515 | attrs, 516 | vis, 517 | struct_token: Token![struct](span), 518 | ident, 519 | generics: Default::default(), 520 | semi_token: match &fields { 521 | Fields::Unit => Some(Token![;](span)), 522 | Fields::Unnamed(_) => 
Some(Token![;](span)), 523 | Fields::Named(_) => None, 524 | }, 525 | fields, 526 | } 527 | } 528 | 529 | fn single_unnamed_field(ty: Type) -> Fields { 530 | let field = syn::Field { 531 | attrs: vec![], 532 | vis: Visibility::Inherited, 533 | ident: None, 534 | colon_token: None, 535 | mutability: syn::FieldMutability::None, 536 | ty, 537 | }; 538 | Fields::Unnamed(syn::FieldsUnnamed { 539 | paren_token: syn::token::Paren(Span::call_site()), 540 | unnamed: Punctuated::from_iter([field]), 541 | }) 542 | } 543 | 544 | fn set_fields_vis(fields: &mut Fields, vis: &Visibility) { 545 | let inner = match fields { 546 | Fields::Named(ref mut named) => named.named.iter_mut(), 547 | Fields::Unnamed(ref mut unnamed) => unnamed.unnamed.iter_mut(), 548 | Fields::Unit => return, 549 | }; 550 | for field in inner { 551 | field.vis = vis.clone(); 552 | } 553 | } 554 | 555 | // Helper function for error reporting 556 | fn error_tokens(span: Span, message: &str) -> TokenStream { 557 | Error::new(span, message).to_compile_error().into() 558 | } 559 | 560 | fn syn_err(span: Span, message: impl std::fmt::Display) -> syn::Result { 561 | Err(Error::new(span, message)) 562 | } 563 | --------------------------------------------------------------------------------