├── .github ├── dependabot.yml └── workflows │ ├── main.yml │ └── pr_review.yml ├── .gitignore ├── CONTRIBUTING.md ├── Cargo.toml ├── LICENSE ├── README.md ├── RELEASES.md ├── example-service ├── Cargo.toml ├── README.md └── src │ ├── client.rs │ ├── lib.rs │ └── server.rs ├── hooks ├── pre-commit └── pre-push ├── plugins ├── Cargo.toml ├── LICENSE ├── rustfmt.toml ├── src │ └── lib.rs └── tests │ ├── server.rs │ └── service.rs └── tarpc ├── Cargo.toml ├── LICENSE ├── README.md ├── clippy.toml ├── examples ├── certs │ └── eddsa │ │ ├── client.cert │ │ ├── client.chain │ │ ├── client.key │ │ ├── end.cert │ │ ├── end.chain │ │ └── end.key ├── compression.rs ├── custom_transport.rs ├── pubsub.rs ├── readme.rs ├── tls_over_tcp.rs └── tracing.rs ├── rustfmt.toml ├── src ├── cancellations.rs ├── client.rs ├── client │ ├── in_flight_requests.rs │ ├── stub.rs │ └── stub │ │ ├── load_balance.rs │ │ ├── mock.rs │ │ └── retry.rs ├── context.rs ├── lib.rs ├── serde_transport.rs ├── server.rs ├── server │ ├── in_flight_requests.rs │ ├── incoming.rs │ ├── limits.rs │ ├── limits │ │ ├── channels_per_key.rs │ │ └── requests_per_channel.rs │ ├── request_hook.rs │ ├── request_hook │ │ ├── after.rs │ │ ├── before.rs │ │ └── before_and_after.rs │ └── testing.rs ├── trace.rs ├── transport.rs ├── transport │ └── channel.rs ├── util.rs └── util │ └── serde.rs └── tests ├── compile_fail.rs ├── compile_fail ├── must_use_request_dispatch.rs ├── must_use_request_dispatch.stderr ├── no_serde1 │ ├── no_explicit_serde_without_feature.rs │ ├── no_explicit_serde_without_feature.stderr │ ├── no_implicit_serde_without_feature.rs │ └── no_implicit_serde_without_feature.stderr ├── serde1 │ ├── deprecated.rs │ ├── deprecated.stderr │ ├── incompatible.rs │ ├── incompatible.stderr │ ├── opt_out_serde.rs │ └── opt_out_serde.stderr ├── serde_transport │ ├── must_use_tcp_connect.rs │ └── must_use_tcp_connect.stderr ├── tarpc_service_arg_pat.rs ├── tarpc_service_arg_pat.stderr ├── 
tarpc_service_derive_serde.rs ├── tarpc_service_derive_serde.stderr ├── tarpc_service_fn_new.rs ├── tarpc_service_fn_new.stderr ├── tarpc_service_fn_serve.rs └── tarpc_service_fn_serve.stderr ├── dataservice.rs ├── proc_macro_hygene.rs └── service_functional.rs /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Please see the documentation for all configuration options: 2 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 3 | # 4 | version: 2 5 | updates: 6 | - package-ecosystem: "cargo" 7 | directory: "/" 8 | schedule: 9 | interval: "weekly" 10 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - master 5 | pull_request: 6 | branches: 7 | - master 8 | merge_group: 9 | branches: 10 | - master 11 | 12 | name: Continuous Integration 13 | 14 | concurrency: 15 | group: "${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}" 16 | cancel-in-progress: "${{ github.ref != 'refs/heads/master' }}" 17 | 18 | jobs: 19 | test: 20 | name: Test 21 | runs-on: ubuntu-latest 22 | strategy: 23 | matrix: 24 | serde: ["--features serde1", ""] 25 | tokio: ["--features tokio1", ""] 26 | serde-transport: ["--features serde-transport", ""] 27 | serde-transport-json: ["--features serde-transport-json", ""] 28 | serde-transport-bincode: ["--features serde-transport-bincode", ""] 29 | tcp: ["--features tcp", ""] 30 | unix: ["--features unix", ""] 31 | exclude: 32 | - serde-transport-json: "--features serde-transport-json" 33 | serde-transport: "" 34 | - serde-transport-bincode: "--features serde-transport-bincode" 35 | serde-transport: "" 36 | - serde-transport: "--features serde-transport" 37 | tokio: "" 38 | - serde-transport: "--features serde-transport" 
39 | serde: "" 40 | - tcp: "--features tcp" 41 | serde-transport: "" 42 | - unix: "--features unix" 43 | serde-transport: "" 44 | steps: 45 | - uses: actions/checkout@v4 46 | - uses: dtolnay/rust-toolchain@stable 47 | - run: > 48 | cargo test --manifest-path tarpc/Cargo.toml 49 | ${{ matrix.serde }} ${{ matrix.tokio }} ${{ matrix.serde-transport }} 50 | ${{ matrix.serde-transport-json }} ${{ matrix.serde-transport-bincode }} 51 | ${{ matrix.tcp }} ${{ matrix.unix }} 52 | 53 | list-examples: 54 | name: List Examples 55 | runs-on: ubuntu-latest 56 | outputs: 57 | examples: ${{ steps.matrix.outputs.examples }} 58 | steps: 59 | - uses: actions/checkout@v4 60 | - uses: dtolnay/rust-toolchain@stable 61 | - id: matrix 62 | run: | 63 | examples=$( 64 | cargo metadata --no-deps --format-version=1 \ 65 | | jq '.packages[] 66 | | select ( .name == "tarpc" ) 67 | | .targets[] 68 | | select (.kind[] | . == "example") 69 | | .name' \ 70 | | jq -s -c '.' 71 | ) 72 | echo "examples=$examples" | tee -a $GITHUB_OUTPUT 73 | 74 | run-example: 75 | name: Run Example 76 | needs: list-examples 77 | runs-on: ubuntu-latest 78 | strategy: 79 | matrix: 80 | example: ${{ fromJSON(needs.list-examples.outputs.examples) }} 81 | steps: 82 | - uses: actions/checkout@v4 83 | - uses: dtolnay/rust-toolchain@stable 84 | - run: | 85 | cargo run --example "${{ matrix.example }}" 86 | 87 | fmt: 88 | name: Rustfmt 89 | runs-on: ubuntu-latest 90 | steps: 91 | - uses: actions/checkout@v4 92 | - uses: dtolnay/rust-toolchain@stable 93 | with: 94 | components: rustfmt 95 | - run: cargo fmt --all -- --check 96 | 97 | clippy: 98 | name: Clippy 99 | runs-on: ubuntu-latest 100 | steps: 101 | - uses: actions/checkout@v4 102 | - uses: dtolnay/rust-toolchain@stable 103 | with: 104 | components: clippy 105 | - run: cargo clippy --all-features --all-targets -- -D warnings 106 | 107 | # This job succeeds if all other tests and examples succeed. Otherwise, it fails. It is for use in 108 | # branch protection rules. 
109 | test-suite: 110 | name: Test Suite 111 | runs-on: ubuntu-latest 112 | needs: [test, run-example, fmt, clippy] 113 | if: always() 114 | steps: 115 | - name: All tests ok 116 | if: ${{ !(contains(needs.*.result, 'failure')) }} 117 | run: exit 0 118 | - name: Some tests failed 119 | if: ${{ contains(needs.*.result, 'failure') }} 120 | run: exit 1 121 | -------------------------------------------------------------------------------- /.github/workflows/pr_review.yml: -------------------------------------------------------------------------------- 1 | name: PR Review 2 | on: [pull_request] 3 | jobs: 4 | clippy: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v4 8 | - uses: dtolnay/rust-toolchain@stable 9 | with: 10 | components: clippy 11 | - uses: giraffate/clippy-action@v1 12 | with: 13 | reporter: 'github-pr-review' 14 | github_token: ${{ github.token }} 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | Cargo.lock 3 | .cargo 4 | *.swp 5 | *.bk 6 | tarpc.iml 7 | .idea 8 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Want to contribute? Great! First, read this page (including the small print at the end). 2 | 3 | ### Before you contribute 4 | Before we can use your code, you must sign the 5 | [Google Individual Contributor License Agreement] 6 | (https://cla.developers.google.com/about/google-individual) 7 | (CLA), which you can do online. The CLA is necessary mainly because you own the 8 | copyright to your changes, even after your contribution becomes part of our 9 | codebase, so we need your permission to use and distribute your code. 
We also 10 | need to be sure of various other things—for instance that you'll tell us if you 11 | know that your code infringes on other people's patents. You don't have to sign 12 | the CLA until after you've submitted your code for review and a member has 13 | approved it, but you must do it before we can put your code into our codebase. 14 | Before you start working on a larger contribution, you should get in touch with 15 | us first through the issue tracker with your idea so that we can help out and 16 | possibly guide you. Coordinating up front makes it much easier to avoid 17 | frustration later on. 18 | 19 | ### Code reviews 20 | All submissions, including submissions by project members, require review. We 21 | use Github pull requests for this purpose. 22 | 23 | ### The small print 24 | Contributions made by corporations are covered by a different agreement than 25 | the one above, the 26 | [Software Grant and Corporate Contributor License Agreement] 27 | (https://cla.developers.google.com/about/google-corporate). 28 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | resolver = "2" 3 | 4 | members = [ 5 | "example-service", 6 | "tarpc", 7 | "plugins", 8 | ] 9 | 10 | [profile.dev] 11 | split-debuginfo = "unpacked" 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright 2016 Google Inc. All Rights Reserved. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Crates.io][crates-badge]][crates-url] 2 | [![MIT licensed][mit-badge]][mit-url] 3 | [![Build status][gh-actions-badge]][gh-actions-url] 4 | [![Discord chat][discord-badge]][discord-url] 5 | 6 | [crates-badge]: https://img.shields.io/crates/v/tarpc.svg 7 | [crates-url]: https://crates.io/crates/tarpc 8 | [mit-badge]: https://img.shields.io/badge/license-MIT-blue.svg 9 | [mit-url]: LICENSE 10 | [gh-actions-badge]: https://github.com/google/tarpc/workflows/Continuous%20Integration/badge.svg 11 | [gh-actions-url]: https://github.com/google/tarpc/actions?query=workflow%3A%22Continuous+Integration%22 12 | [discord-badge]: https://img.shields.io/discord/647529123996237854.svg?logo=discord&style=flat-square 13 | [discord-url]: https://discord.gg/gXwpdSt 14 | 15 | # tarpc 16 | 17 | 18 | 19 | *Disclaimer*: This is not an official Google product. 20 | 21 | tarpc is an RPC framework for rust with a focus on ease of use. Defining a 22 | service can be done in just a few lines of code, and most of the boilerplate of 23 | writing a server is taken care of for you. 24 | 25 | [Documentation](https://docs.rs/crate/tarpc/) 26 | 27 | ## What is an RPC framework? 28 | "RPC" stands for "Remote Procedure Call," a function call where the work of 29 | producing the return value is being done somewhere else. When an rpc function is 30 | invoked, behind the scenes the function contacts some other process somewhere 31 | and asks them to evaluate the function instead. The original function then 32 | returns the value produced by the other process. 33 | 34 | RPC frameworks are a fundamental building block of most microservices-oriented 35 | architectures. Two well-known ones are [gRPC](http://www.grpc.io) and 36 | [Cap'n Proto](https://capnproto.org/). 
37 | 38 | tarpc differentiates itself from other RPC frameworks by defining the schema in code, 39 | rather than in a separate language such as .proto. This means there's no separate compilation 40 | process, and no context switching between different languages. 41 | 42 | Some other features of tarpc: 43 | - Pluggable transport: any type implementing `Stream + Sink` can be 44 | used as a transport to connect the client and server. 45 | - `Send + 'static` optional: if the transport doesn't require it, neither does tarpc! 46 | - Cascading cancellation: dropping a request will send a cancellation message to the server. 47 | The server will cease any unfinished work on the request, subsequently cancelling any of its 48 | own requests, repeating for the entire chain of transitive dependencies. 49 | - Configurable deadlines and deadline propagation: request deadlines default to 10s if 50 | unspecified. The server will automatically cease work when the deadline has passed. Any 51 | requests sent by the server that use the request context will propagate the request deadline. 52 | For example, if a server is handling a request with a 10s deadline, does 2s of work, then 53 | sends a request to another server, that server will see an 8s deadline. 54 | - Distributed tracing: tarpc is instrumented with 55 | [tracing](https://github.com/tokio-rs/tracing) primitives extended with 56 | [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible tracing subscriber like 57 | [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger), 58 | each RPC can be traced through the client, server, and other dependencies downstream of the 59 | server. Even for applications not connected to a distributed tracing collector, the 60 | instrumentation can also be ingested by regular loggers like 61 | [env_logger](https://github.com/env-logger-rs/env_logger/). 
62 | - Serde serialization: enabling the `serde1` Cargo feature will make service requests and 63 | responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can 64 | be used, as well, so the price of serialization doesn't have to be paid when it's not needed. 65 | 66 | ## Usage 67 | Add to your `Cargo.toml` dependencies: 68 | 69 | ```toml 70 | tarpc = "0.36" 71 | ``` 72 | 73 | The `tarpc::service` attribute expands to a collection of items that form an rpc service. 74 | These generated types make it easy and ergonomic to write servers with less boilerplate. 75 | Simply implement the generated service trait, and you're off to the races! 76 | 77 | ## Example 78 | 79 | This example uses [tokio](https://tokio.rs), so add the following dependencies to 80 | your `Cargo.toml`: 81 | 82 | ```toml 83 | anyhow = "1.0" 84 | futures = "0.3" 85 | tarpc = { version = "0.36", features = ["tokio1"] } 86 | tokio = { version = "1.0", features = ["rt-multi-thread", "macros"] } 87 | ``` 88 | 89 | In the following example, we use an in-process channel for communication between 90 | client and server. In real code, you will likely communicate over the network. 91 | For a more real-world example, see [example-service](example-service). 92 | 93 | First, let's set up the dependencies and service definition. 94 | 95 | ```rust 96 | use futures::prelude::*; 97 | use tarpc::{ 98 | client, context, 99 | server::{self, Channel}, 100 | }; 101 | 102 | // This is the service definition. It looks a lot like a trait definition. 103 | // It defines one RPC, hello, which takes one arg, name, and returns a String. 104 | #[tarpc::service] 105 | trait World { 106 | /// Returns a greeting for name. 107 | async fn hello(name: String) -> String; 108 | } 109 | ``` 110 | 111 | This service definition generates a trait called `World`. Next we need to 112 | implement it for our Server struct. 113 | 114 | ```rust 115 | // This is the type that implements the generated World trait. 
It is the business logic 116 | // and is used to start the server. 117 | #[derive(Clone)] 118 | struct HelloServer; 119 | 120 | impl World for HelloServer { 121 | async fn hello(self, _: context::Context, name: String) -> String { 122 | format!("Hello, {name}!") 123 | } 124 | } 125 | ``` 126 | 127 | Lastly let's write our `main` that will start the server. While this example uses an 128 | [in-process channel](transport::channel), tarpc also ships a generic [`serde_transport`] 129 | behind the `serde-transport` feature, with additional [TCP](serde_transport::tcp) functionality 130 | available behind the `tcp` feature. 131 | 132 | ```rust 133 | #[tokio::main] 134 | async fn main() -> anyhow::Result<()> { 135 | let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); 136 | 137 | let server = server::BaseChannel::with_defaults(server_transport); 138 | tokio::spawn( 139 | server.execute(HelloServer.serve()) 140 | // Handle all requests concurrently. 141 | .for_each(|response| async move { 142 | tokio::spawn(response); 143 | })); 144 | 145 | // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` 146 | // that takes a config and any Transport as input. 147 | let mut client = WorldClient::new(client::Config::default(), client_transport).spawn(); 148 | 149 | // The client has an RPC method for each RPC defined in the annotated trait. It takes the same 150 | // args as defined, with the addition of a Context, which is always the first arg. The Context 151 | // specifies a deadline and trace information which can be helpful in debugging requests. 152 | let hello = client.hello(context::current(), "Stim".to_string()).await?; 153 | 154 | println!("{hello}"); 155 | 156 | Ok(()) 157 | } 158 | ``` 159 | 160 | ## Service Documentation 161 | 162 | Use `cargo doc` as you normally would to see the documentation created for all 163 | items expanded by a `service!` invocation. 
164 | 165 | 166 | 167 | License: MIT 168 | -------------------------------------------------------------------------------- /example-service/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tarpc-example-service" 3 | version = "0.16.1" 4 | rust-version = "1.65.0" 5 | authors = ["Tim Kuehn "] 6 | edition = "2021" 7 | license = "MIT" 8 | documentation = "https://docs.rs/tarpc-example-service" 9 | homepage = "https://github.com/google/tarpc" 10 | repository = "https://github.com/google/tarpc" 11 | keywords = ["rpc", "network", "server", "microservices", "example"] 12 | categories = ["asynchronous", "network-programming"] 13 | readme = "README.md" 14 | description = "An example server built on tarpc." 15 | 16 | [dependencies] 17 | anyhow = "1.0" 18 | clap = { version = "4.4.18", features = ["derive"] } 19 | log = "0.4" 20 | futures = "0.3" 21 | opentelemetry = { version = "0.26.0" } 22 | opentelemetry-otlp = "0.26.0" 23 | rand = "0.8" 24 | tarpc = { version = "0.36", path = "../tarpc", features = ["full"] } 25 | tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] } 26 | tracing = { version = "0.1" } 27 | tracing-opentelemetry = "0.27.0" 28 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 29 | opentelemetry_sdk = { version = "0.26.0", features = ["rt-tokio"] } 30 | opentelemetry-semantic-conventions = "0.16.0" 31 | 32 | [lib] 33 | name = "service" 34 | path = "src/lib.rs" 35 | 36 | [[bin]] 37 | name = "server" 38 | path = "src/server.rs" 39 | 40 | [[bin]] 41 | name = "client" 42 | path = "src/client.rs" 43 | -------------------------------------------------------------------------------- /example-service/README.md: -------------------------------------------------------------------------------- 1 | # Example 2 | 3 | Example service to demonstrate how to set up `tarpc` with [Jaeger](https://www.jaegertracing.io) using OTLP. 
To see traces Jaeger, run the following with `RUST_LOG=trace`. 4 | 5 | ## Server 6 | 7 | ```bash 8 | cargo run --bin server -- --port 50051 9 | ``` 10 | 11 | ## Client 12 | 13 | ```bash 14 | cargo run --bin client -- --server-addr "[::1]:50051" --name "Bob" 15 | ``` 16 | -------------------------------------------------------------------------------- /example-service/src/client.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | use clap::Parser; 8 | use service::{init_tracing, WorldClient}; 9 | use std::{net::SocketAddr, time::Duration}; 10 | use tarpc::{client, context, tokio_serde::formats::Json}; 11 | use tokio::time::sleep; 12 | use tracing::Instrument; 13 | 14 | #[derive(Parser)] 15 | struct Flags { 16 | /// Sets the server address to connect to. 17 | #[clap(long)] 18 | server_addr: SocketAddr, 19 | /// Sets the name to say hello to. 20 | #[clap(long)] 21 | name: String, 22 | } 23 | 24 | #[tokio::main] 25 | async fn main() -> anyhow::Result<()> { 26 | let flags = Flags::parse(); 27 | init_tracing("Tarpc Example Client")?; 28 | 29 | let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); 30 | transport.config_mut().max_frame_length(usize::MAX); 31 | 32 | // WorldClient is generated by the service attribute. It has a constructor `new` that takes a 33 | // config and any Transport as input. 34 | let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); 35 | 36 | let hello = async move { 37 | // Send the request twice, just to be safe! ;) 38 | tokio::select! 
{ 39 | hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 } 40 | hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 } 41 | } 42 | } 43 | .instrument(tracing::info_span!("Two Hellos")) 44 | .await; 45 | 46 | match hello { 47 | Ok(hello) => tracing::info!("{hello:?}"), 48 | Err(e) => tracing::warn!("{:?}", anyhow::Error::from(e)), 49 | } 50 | 51 | // Let the background span processor finish. 52 | sleep(Duration::from_micros(1)).await; 53 | opentelemetry::global::shutdown_tracer_provider(); 54 | 55 | Ok(()) 56 | } 57 | -------------------------------------------------------------------------------- /example-service/src/lib.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | use opentelemetry::trace::TracerProvider as _; 8 | use tracing_subscriber::{fmt::format::FmtSpan, prelude::*}; 9 | 10 | /// This is the service definition. It looks a lot like a trait definition. 11 | /// It defines one RPC, hello, which takes one arg, name, and returns a String. 12 | #[tarpc::service] 13 | pub trait World { 14 | /// Returns a greeting for name. 15 | async fn hello(name: String) -> String; 16 | } 17 | 18 | /// Initializes an OpenTelemetry tracing subscriber with a OTLP backend. 
19 | pub fn init_tracing(service_name: &'static str) -> anyhow::Result<()> { 20 | let tracer_provider = opentelemetry_otlp::new_pipeline() 21 | .tracing() 22 | .with_trace_config(opentelemetry_sdk::trace::Config::default().with_resource( 23 | opentelemetry_sdk::Resource::new([opentelemetry::KeyValue::new( 24 | opentelemetry_semantic_conventions::resource::SERVICE_NAME, 25 | service_name, 26 | )]), 27 | )) 28 | .with_batch_config(opentelemetry_sdk::trace::BatchConfig::default()) 29 | .with_exporter(opentelemetry_otlp::new_exporter().tonic()) 30 | .install_batch(opentelemetry_sdk::runtime::Tokio)?; 31 | opentelemetry::global::set_tracer_provider(tracer_provider.clone()); 32 | let tracer = tracer_provider.tracer(service_name); 33 | 34 | tracing_subscriber::registry() 35 | .with(tracing_subscriber::EnvFilter::from_default_env()) 36 | .with(tracing_subscriber::fmt::layer().with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)) 37 | .with(tracing_opentelemetry::layer().with_tracer(tracer)) 38 | .try_init()?; 39 | 40 | Ok(()) 41 | } 42 | -------------------------------------------------------------------------------- /example-service/src/server.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | use clap::Parser; 8 | use futures::{future, prelude::*}; 9 | use rand::{ 10 | distributions::{Distribution, Uniform}, 11 | thread_rng, 12 | }; 13 | use service::{init_tracing, World}; 14 | use std::{ 15 | net::{IpAddr, Ipv6Addr, SocketAddr}, 16 | time::Duration, 17 | }; 18 | use tarpc::{ 19 | context, 20 | server::{self, incoming::Incoming, Channel}, 21 | tokio_serde::formats::Json, 22 | }; 23 | use tokio::time; 24 | 25 | #[derive(Parser)] 26 | struct Flags { 27 | /// Sets the port number to listen on. 
28 | #[clap(long)] 29 | port: u16, 30 | } 31 | 32 | // This is the type that implements the generated World trait. It is the business logic 33 | // and is used to start the server. 34 | #[derive(Clone)] 35 | struct HelloServer(SocketAddr); 36 | 37 | impl World for HelloServer { 38 | async fn hello(self, _: context::Context, name: String) -> String { 39 | let sleep_time = 40 | Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); 41 | time::sleep(sleep_time).await; 42 | format!("Hello, {name}! You are connected from {}", self.0) 43 | } 44 | } 45 | 46 | async fn spawn(fut: impl Future + Send + 'static) { 47 | tokio::spawn(fut); 48 | } 49 | 50 | #[tokio::main] 51 | async fn main() -> anyhow::Result<()> { 52 | let flags = Flags::parse(); 53 | init_tracing("Tarpc Example Server")?; 54 | 55 | let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), flags.port); 56 | 57 | // JSON transport is provided by the json_transport tarpc module. It makes it easy 58 | // to start up a serde-powered json serialization strategy over TCP. 59 | let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?; 60 | tracing::info!("Listening on port {}", listener.local_addr().port()); 61 | listener.config_mut().max_frame_length(usize::MAX); 62 | listener 63 | // Ignore accept errors. 64 | .filter_map(|r| future::ready(r.ok())) 65 | .map(server::BaseChannel::with_defaults) 66 | // Limit channels to 1 per IP. 67 | .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip()) 68 | // serve is generated by the service attribute. It takes as input any type implementing 69 | // the generated World trait. 70 | .map(|channel| { 71 | let server = HelloServer(channel.transport().peer_addr().unwrap()); 72 | channel.execute(server.serve()).for_each(spawn) 73 | }) 74 | // Max 10 channels. 
75 | .buffer_unordered(10) 76 | .for_each(|_| async {}) 77 | .await; 78 | 79 | Ok(()) 80 | } 81 | -------------------------------------------------------------------------------- /hooks/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2016 Google Inc. All Rights Reserved. 4 | # 5 | # Licensed under the MIT License, . 6 | # This file may not be copied, modified, or distributed except according to those terms. 7 | 8 | # 9 | # Pre-commit hook for the tarpc repository. To use this hook, copy it to .git/hooks in your 10 | # repository root. 11 | # 12 | # This precommit checks the following: 13 | # 1. All filenames are ascii 14 | # 2. There is no bad whitespace 15 | # 3. rustfmt is installed 16 | # 4. rustfmt is a noop on files that are in the index 17 | # 18 | # Options: 19 | # 20 | # - TARPC_SKIP_RUSTFMT, default = 0 21 | # 22 | # Set this to 1 to skip running rustfmt 23 | # 24 | # Note that these options are most useful for testing the hooks themselves. Use git commit 25 | # --no-verify to skip the pre-commit hook altogether. 26 | 27 | RED='\033[0;31m' 28 | GREEN='\033[0;32m' 29 | YELLOW='\033[0;33m' 30 | NC='\033[0m' # No Color 31 | 32 | PREFIX="${GREEN}[PRECOMMIT]${NC}" 33 | FAILURE="${RED}FAILED${NC}" 34 | WARNING="${RED}[WARNING]${NC}" 35 | SKIPPED="${YELLOW}SKIPPED${NC}" 36 | SUCCESS="${GREEN}ok${NC}" 37 | 38 | if git rev-parse --verify HEAD &>/dev/null 39 | then 40 | against=HEAD 41 | else 42 | # Initial commit: diff against an empty tree object 43 | against=4b825dc642cb6eb9a060e54bf8d69288fbee4904 44 | fi 45 | 46 | FAILED=0 47 | 48 | printf "${PREFIX} Checking that all filenames are ascii ... " 49 | # Note that the use of brackets around a tr range is ok here, (it's 50 | # even required, for portability to Solaris 10's /usr/bin/tr), since 51 | # the square bracket bytes happen to fall in the designated range. 
52 | if test $(git diff --cached --name-only --diff-filter=A -z $against | LC_ALL=C tr -d '[ -~]\0' | wc -c) != 0 53 | then 54 | FAILED=1 55 | printf "${FAILURE}\n" 56 | else 57 | printf "${SUCCESS}\n" 58 | fi 59 | 60 | printf "${PREFIX} Checking for bad whitespace ... " 61 | git diff-index --check --cached $against -- &>/dev/null 62 | if [ "$?" != 0 ]; then 63 | FAILED=1 64 | printf "${FAILURE}\n" 65 | else 66 | printf "${SUCCESS}\n" 67 | fi 68 | 69 | printf "${PREFIX} Checking for rustfmt ... " 70 | command -v rustfmt &>/dev/null 71 | if [ $? == 0 ]; then 72 | printf "${SUCCESS}\n" 73 | else 74 | printf "${FAILURE}\n" 75 | exit 1 76 | fi 77 | 78 | printf "${PREFIX} Checking for shasum ... " 79 | command -v shasum &>/dev/null 80 | if [ $? == 0 ]; then 81 | printf "${SUCCESS}\n" 82 | else 83 | printf "${FAILURE}\n" 84 | exit 1 85 | fi 86 | 87 | # Just check that running rustfmt doesn't do anything to the file. I do this instead of 88 | # modifying the file because I don't want to mess with the developer's index, which may 89 | # not only contain discrete files. 90 | printf "${PREFIX} Checking formatting ... " 91 | FMTRESULT=0 92 | diff="" 93 | for file in $(git diff --name-only --cached); 94 | do 95 | if [ ${file: -3} == ".rs" ]; then 96 | diff="$diff$(rustfmt --edition 2018 --check $file)" 97 | if [ $? != 0 ]; then 98 | FMTRESULT=1 99 | fi 100 | fi 101 | done 102 | 103 | if [ "${TARPC_SKIP_RUSTFMT}" == 1 ]; then 104 | printf "${SKIPPED}\n"$? 105 | elif [ ${FMTRESULT} != 0 ]; then 106 | FAILED=1 107 | printf "${FAILURE}\n" 108 | echo "$diff" 109 | else 110 | printf "${SUCCESS}\n" 111 | fi 112 | 113 | exit ${FAILED} 114 | -------------------------------------------------------------------------------- /hooks/pre-push: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2016 Google Inc. All Rights Reserved. 4 | # 5 | # Licensed under the MIT License, . 
6 | # This file may not be copied, modified, or distributed except according to those terms. 7 | 8 | # Pre-push hook for the tarpc repository. To use this hook, copy it to .git/hooks in your repository 9 | # root. 10 | # 11 | # This hook ensures the working copy does not contain uncommitted changes, as it is a common error 12 | # to test locally using a dirty working copy without realizing the tests are using a dirty working 13 | # copy. 14 | # 15 | # Use git push --no-verify to skip the pre-push hook altogether. 16 | 17 | RED='\033[0;31m' 18 | GREEN='\033[0;32m' 19 | YELLOW='\033[0;33m' 20 | NC='\033[0m' # No Color 21 | 22 | PREFIX="${GREEN}[PREPUSH]${NC}" 23 | FAILURE="${RED}FAILED${NC}" 24 | SKIPPED="${YELLOW}SKIPPED${NC}" 25 | SUCCESS="${GREEN}ok${NC}" 26 | 27 | printf "${PREFIX} Clean working copy ... " 28 | git diff --exit-code &>/dev/null 29 | if [ "$?" == 0 ]; then 30 | printf "${SUCCESS}\n" 31 | else 32 | printf "${FAILURE}\n" 33 | exit 1 34 | fi 35 | 36 | exit 0 37 | -------------------------------------------------------------------------------- /plugins/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tarpc-plugins" 3 | version = "0.14.1" 4 | rust-version = "1.65.0" 5 | authors = ["Adam Wright ", "Tim Kuehn "] 6 | edition = "2021" 7 | license = "MIT" 8 | documentation = "https://docs.rs/tarpc-plugins" 9 | homepage = "https://github.com/google/tarpc" 10 | repository = "https://github.com/google/tarpc" 11 | keywords = ["rpc", "network", "server", "api", "microservices"] 12 | categories = ["asynchronous", "network-programming"] 13 | readme = "../README.md" 14 | description = "Proc macros for tarpc." 
15 | 16 | [features] 17 | serde1 = [] 18 | 19 | [badges] 20 | travis-ci = { repository = "google/tarpc" } 21 | 22 | [dependencies] 23 | proc-macro2 = "1.0" 24 | quote = "1.0" 25 | syn = { version = "2.0", features = ["full", "extra-traits"] } 26 | 27 | [lib] 28 | proc-macro = true 29 | 30 | [dev-dependencies] 31 | assert-type-eq = "0.1.0" 32 | futures = "0.3" 33 | serde = { version = "1.0", features = ["derive"] } 34 | tarpc = { path = "../tarpc", features = ["serde1"] } 35 | -------------------------------------------------------------------------------- /plugins/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright 2016 Google Inc. All Rights Reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
10 | -------------------------------------------------------------------------------- /plugins/rustfmt.toml: -------------------------------------------------------------------------------- 1 | edition = "2018" 2 | -------------------------------------------------------------------------------- /plugins/tests/server.rs: -------------------------------------------------------------------------------- 1 | // these need to be out here rather than inside the function so that the 2 | // assert_type_eq macro can pick them up. 3 | #[tarpc::service] 4 | trait Foo { 5 | async fn two_part(s: String, i: i32) -> (String, i32); 6 | async fn bar(s: String) -> String; 7 | async fn baz(); 8 | } 9 | 10 | #[allow(non_camel_case_types)] 11 | #[test] 12 | fn raw_idents_work() { 13 | type r#yield = String; 14 | 15 | #[tarpc::service] 16 | trait r#trait { 17 | async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32); 18 | async fn r#fn(r#impl: r#yield) -> r#yield; 19 | async fn r#async(); 20 | } 21 | } 22 | 23 | #[test] 24 | fn syntax() { 25 | #[tarpc::service] 26 | trait Syntax { 27 | #[deny(warnings)] 28 | #[allow(non_snake_case)] 29 | async fn TestCamelCaseDoesntConflict(); 30 | async fn hello() -> String; 31 | #[doc = "attr"] 32 | async fn attr(s: String) -> String; 33 | async fn no_args_no_return(); 34 | async fn no_args() -> (); 35 | async fn one_arg(one: String) -> i32; 36 | async fn two_args_no_return(one: String, two: u64); 37 | async fn two_args(one: String, two: u64) -> String; 38 | async fn no_args_ret_error() -> i32; 39 | async fn one_arg_ret_error(one: String) -> String; 40 | async fn no_arg_implicit_return_error(); 41 | #[doc = "attr"] 42 | async fn one_arg_implicit_return_error(one: String); 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /plugins/tests/service.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use std::hash::Hash; 
3 | use tarpc::context; 4 | 5 | #[test] 6 | fn att_service_trait() { 7 | #[tarpc::service] 8 | trait Foo { 9 | async fn two_part(s: String, i: i32) -> (String, i32); 10 | async fn bar(s: String) -> String; 11 | async fn baz(); 12 | } 13 | 14 | impl Foo for () { 15 | async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { 16 | (s, i) 17 | } 18 | 19 | async fn bar(self, _: context::Context, s: String) -> String { 20 | s 21 | } 22 | 23 | async fn baz(self, _: context::Context) {} 24 | } 25 | } 26 | 27 | #[allow(non_camel_case_types)] 28 | #[test] 29 | fn raw_idents() { 30 | type r#yield = String; 31 | 32 | #[tarpc::service] 33 | trait r#trait { 34 | async fn r#await(r#struct: r#yield, r#enum: i32) -> (r#yield, i32); 35 | async fn r#fn(r#impl: r#yield) -> r#yield; 36 | async fn r#async(); 37 | } 38 | 39 | impl r#trait for () { 40 | async fn r#await( 41 | self, 42 | _: context::Context, 43 | r#struct: r#yield, 44 | r#enum: i32, 45 | ) -> (r#yield, i32) { 46 | (r#struct, r#enum) 47 | } 48 | 49 | async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { 50 | r#impl 51 | } 52 | 53 | async fn r#async(self, _: context::Context) {} 54 | } 55 | } 56 | 57 | #[test] 58 | fn service_with_cfg_rpc() { 59 | #[tarpc::service] 60 | trait Foo { 61 | async fn foo(); 62 | #[cfg(not(test))] 63 | async fn bar(s: String) -> String; 64 | } 65 | 66 | impl Foo for () { 67 | async fn foo(self, _: context::Context) {} 68 | } 69 | } 70 | 71 | #[test] 72 | fn syntax() { 73 | #[tarpc::service] 74 | trait Syntax { 75 | #[deny(warnings)] 76 | #[allow(non_snake_case)] 77 | async fn TestCamelCaseDoesntConflict(); 78 | async fn hello() -> String; 79 | #[doc = "attr"] 80 | async fn attr(s: String) -> String; 81 | async fn no_args_no_return(); 82 | async fn no_args() -> (); 83 | async fn one_arg(one: String) -> i32; 84 | async fn two_args_no_return(one: String, two: u64); 85 | async fn two_args(one: String, two: u64) -> String; 86 | async fn no_args_ret_error() -> 
i32; 87 | async fn one_arg_ret_error(one: String) -> String; 88 | async fn no_arg_implicit_return_error(); 89 | #[doc = "attr"] 90 | async fn one_arg_implicit_return_error(one: String); 91 | } 92 | } 93 | 94 | #[test] 95 | fn custom_derives() { 96 | #[tarpc::service(derive = [Clone, Hash])] 97 | trait Foo { 98 | async fn foo(); 99 | } 100 | 101 | fn requires_clone(_: impl Clone) {} 102 | fn requires_hash(_: impl Hash) {} 103 | 104 | let x = FooRequest::Foo {}; 105 | requires_clone(x.clone()); 106 | requires_hash(x); 107 | } 108 | 109 | #[test] 110 | fn implicit_serde() { 111 | #[tarpc::service] 112 | trait Foo { 113 | async fn foo(); 114 | } 115 | 116 | fn requires_serde(_: T) 117 | where 118 | for<'de> T: Serialize + Deserialize<'de>, 119 | { 120 | } 121 | 122 | let x = FooRequest::Foo {}; 123 | requires_serde(x); 124 | } 125 | 126 | #[allow(deprecated)] 127 | #[test] 128 | fn explicit_serde() { 129 | #[tarpc::service(derive_serde = true)] 130 | trait Foo { 131 | async fn foo(); 132 | } 133 | 134 | fn requires_serde(_: T) 135 | where 136 | for<'de> T: Serialize + Deserialize<'de>, 137 | { 138 | } 139 | 140 | let x = FooRequest::Foo {}; 141 | requires_serde(x); 142 | } 143 | -------------------------------------------------------------------------------- /tarpc/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tarpc" 3 | version = "0.36.0" 4 | rust-version = "1.65.0" 5 | authors = [ 6 | "Adam Wright ", 7 | "Tim Kuehn ", 8 | ] 9 | edition = "2021" 10 | license = "MIT" 11 | documentation = "https://docs.rs/tarpc" 12 | homepage = "https://github.com/google/tarpc" 13 | repository = "https://github.com/google/tarpc" 14 | keywords = ["rpc", "network", "server", "api", "microservices"] 15 | categories = ["asynchronous", "network-programming"] 16 | readme = "README.md" 17 | description = "An RPC framework for Rust with a focus on ease of use." 
18 | 19 | [features] 20 | default = [] 21 | 22 | serde1 = ["tarpc-plugins/serde1", "serde", "serde/derive", "serde/rc"] 23 | tokio1 = ["tokio/rt"] 24 | serde-transport = ["serde1", "tokio1", "tokio-serde", "tokio-util/codec"] 25 | serde-transport-json = ["serde-transport", "tokio-serde/json"] 26 | serde-transport-bincode = ["serde-transport", "tokio-serde/bincode"] 27 | tcp = ["tokio/net"] 28 | unix = ["tokio/net"] 29 | 30 | full = [ 31 | "serde1", 32 | "tokio1", 33 | "serde-transport", 34 | "serde-transport-json", 35 | "serde-transport-bincode", 36 | "tcp", 37 | "unix", 38 | ] 39 | 40 | [badges] 41 | travis-ci = { repository = "google/tarpc" } 42 | 43 | [dependencies] 44 | anyhow = "1.0" 45 | fnv = "1.0" 46 | futures = "0.3" 47 | humantime = "2.0" 48 | pin-project = "1.0" 49 | rand = "0.8" 50 | serde = { optional = true, version = "1.0", features = ["derive"] } 51 | static_assertions = "1.1.0" 52 | tarpc-plugins = { path = "../plugins", version = "0.14" } 53 | thiserror = "2.0" 54 | tokio = { version = "1", features = ["time"] } 55 | tokio-util = { version = "0.7.3", features = ["time"] } 56 | tokio-serde = { optional = true, version = "0.9" } 57 | tracing = { version = "0.1", default-features = false, features = [ 58 | "attributes", 59 | "log", 60 | ] } 61 | tracing-opentelemetry = { version = "0.27.0", default-features = false } 62 | opentelemetry = { version = "0.26.0", default-features = false } 63 | opentelemetry-semantic-conventions = "0.16.0" 64 | 65 | [dev-dependencies] 66 | assert_matches = "1.4" 67 | bincode = "1.3" 68 | bytes = { version = "1", features = ["serde"] } 69 | flate2 = "1.0" 70 | futures-test = "0.3" 71 | opentelemetry = { version = "0.26.0", default-features = false } 72 | opentelemetry-otlp = "0.26.0" 73 | opentelemetry_sdk = { version = "0.26.0", features = ["rt-tokio"] } 74 | pin-utils = "0.1.0" 75 | serde_bytes = "0.11" 76 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 77 | tokio = { version = "1", features = 
["full", "test-util", "tracing"] } 78 | console-subscriber = "0.4" 79 | tokio-serde = { version = "0.9", features = ["json", "bincode"] } 80 | trybuild = "1.0" 81 | tokio-rustls = "0.26" 82 | rustls-pemfile = "2.0" 83 | 84 | [package.metadata.docs.rs] 85 | all-features = true 86 | rustdoc-args = ["--cfg", "docsrs"] 87 | 88 | [[example]] 89 | name = "compression" 90 | required-features = ["serde-transport", "tcp"] 91 | 92 | [[example]] 93 | name = "tracing" 94 | required-features = ["full"] 95 | 96 | [[example]] 97 | name = "readme" 98 | required-features = ["full"] 99 | 100 | [[example]] 101 | name = "pubsub" 102 | required-features = ["full"] 103 | 104 | [[example]] 105 | name = "custom_transport" 106 | required-features = ["serde1", "tokio1", "serde-transport"] 107 | 108 | [[example]] 109 | name = "tls_over_tcp" 110 | required-features = ["full"] 111 | 112 | [[test]] 113 | name = "service_functional" 114 | required-features = ["serde-transport"] 115 | 116 | [[test]] 117 | name = "dataservice" 118 | required-features = ["serde-transport", "tcp"] 119 | -------------------------------------------------------------------------------- /tarpc/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright 2016 Google Inc. All Rights Reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /tarpc/README.md: -------------------------------------------------------------------------------- 1 | ../README.md -------------------------------------------------------------------------------- /tarpc/clippy.toml: -------------------------------------------------------------------------------- 1 | doc-valid-idents = ["gRPC"] 2 | -------------------------------------------------------------------------------- /tarpc/examples/certs/eddsa/client.cert: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBlDCCAUagAwIBAgICAxUwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk 3 | RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw 4 | NjEwMTEwNFowGjEYMBYGA1UEAwwPcG9ueXRvd24gY2xpZW50MCowBQYDK2VwAyEA 5 | NTKuLume19IhJfEFd/5OZUuYDKZH6xvy4AGver17OoejgZswgZgwDAYDVR0TAQH/ 6 | BAIwADALBgNVHQ8EBAMCBsAwFgYDVR0lAQH/BAwwCgYIKwYBBQUHAwIwHQYDVR0O 7 | BBYEFDjdrlMu4tyw5MHtbg7WnzSGRBpFMEQGA1UdIwQ9MDuAFHIl7fHKWP6/l8FE 8 | fI2YEIM3oHxKoSCkHjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQYIBezAF 9 | BgMrZXADQQCaahfj/QLxoCOpvl6y0ZQ9CpojPqBnxV3460j5nUOp040Va2MpF137 10 | izCBY7LwgUE/YG6E+kH30G4jMEnqVEYK 11 | -----END CERTIFICATE----- 12 | -------------------------------------------------------------------------------- /tarpc/examples/certs/eddsa/client.chain: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | 
MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE 3 | U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD 4 | DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh 5 | AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU 6 | ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG 7 | AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU 8 | oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc 9 | zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg= 10 | -----END CERTIFICATE----- 11 | -----BEGIN CERTIFICATE----- 12 | MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG 13 | A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0 14 | MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh 15 | ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU 16 | phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR 17 | W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC 18 | t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB 19 | -----END CERTIFICATE----- 20 | -------------------------------------------------------------------------------- /tarpc/examples/certs/eddsa/client.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MC4CAQAwBQYDK2VwBCIEIIJX9ThTHpVS1SNZb6HP4myg4fRInIVGunTRdgnc+weH 3 | -----END PRIVATE KEY----- 4 | -------------------------------------------------------------------------------- /tarpc/examples/certs/eddsa/end.cert: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBuDCCAWqgAwIBAgICAcgwBQYDK2VwMC4xLDAqBgNVBAMMI3Bvbnl0b3duIEVk 3 | RFNBIGxldmVsIDIgaW50ZXJtZWRpYXRlMB4XDTIzMDMxNzEwMTEwNFoXDTI4MDkw 4 | NjEwMTEwNFowGTEXMBUGA1UEAwwOdGVzdHNlcnZlci5jb20wKjAFBgMrZXADIQDc 5 | RLl3/N2tPoWnzBV3noVn/oheEl8IUtiY11Vg/QXTUKOBwDCBvTAMBgNVHRMBAf8E 6 | 
AjAAMAsGA1UdDwQEAwIGwDAdBgNVHQ4EFgQUk7U2mnxedNWBAH84BsNy5si3ZQow 7 | RAYDVR0jBD0wO4AUciXt8cpY/r+XwUR8jZgQgzegfEqhIKQeMBwxGjAYBgNVBAMM 8 | EXBvbnl0b3duIEVkRFNBIENBggF7MDsGA1UdEQQ0MDKCDnRlc3RzZXJ2ZXIuY29t 9 | ghVzZWNvbmQudGVzdHNlcnZlci5jb22CCWxvY2FsaG9zdDAFBgMrZXADQQCFWIcF 10 | 9FiztCuUNzgXDNu5kshuflt0RjkjWpGlWzQjGoYM2IvYhNVPeqnCiY92gqwDSBtq 11 | amD2TBup4eNUCsQB 12 | -----END CERTIFICATE----- 13 | -------------------------------------------------------------------------------- /tarpc/examples/certs/eddsa/end.chain: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBeDCCASqgAwIBAgIBezAFBgMrZXAwHDEaMBgGA1UEAwwRcG9ueXRvd24gRWRE 3 | U0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0MTAxMTA0WjAuMSwwKgYDVQQD 4 | DCNwb255dG93biBFZERTQSBsZXZlbCAyIGludGVybWVkaWF0ZTAqMAUGAytlcAMh 5 | AEFsAexz4x2R4k4+PnTbvRVn0r3F/qw/zVnNBxfGcoEpo38wfTAdBgNVHQ4EFgQU 6 | ciXt8cpY/r+XwUR8jZgQgzegfEowIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwEGCCsG 7 | AQUFBwMCMAwGA1UdEwQFMAMBAf8wCwYDVR0PBAQDAgH+MB8GA1UdIwQYMBaAFKYU 8 | oLdKeY7mp7QgMZKrkVtSWYBKMAUGAytlcANBAHVpNpCV8nu4fkH3Smikx5A9qtHc 9 | zgLIyp+wrF1a4YSa6sfTvuQmJd5aF23OXgq5grCOPXtdpHO50Mx5Qy74zQg= 10 | -----END CERTIFICATE----- 11 | -----BEGIN CERTIFICATE----- 12 | MIIBTDCB/6ADAgECAhRZLuF0TWjDs/31OO8VeKHkNIJQaDAFBgMrZXAwHDEaMBgG 13 | A1UEAwwRcG9ueXRvd24gRWREU0EgQ0EwHhcNMjMwMzE3MTAxMTA0WhcNMzMwMzE0 14 | MTAxMTA0WjAcMRowGAYDVQQDDBFwb255dG93biBFZERTQSBDQTAqMAUGAytlcAMh 15 | ABRPZ4TiuBE8CqAFByZvqpMo/unjnnryfG2AkkWGXpa3o1MwUTAdBgNVHQ4EFgQU 16 | phSgt0p5juantCAxkquRW1JZgEowHwYDVR0jBBgwFoAUphSgt0p5juantCAxkquR 17 | W1JZgEowDwYDVR0TAQH/BAUwAwEB/zAFBgMrZXADQQB29o8erJA0/a8/xOHilOCC 18 | t/s5wPHHnS5NSKx/m2N2nRn3zPxEnETlrAmGulJoeKOx8OblwmPi9rBT2K+QY2UB 19 | -----END CERTIFICATE----- 20 | -------------------------------------------------------------------------------- /tarpc/examples/certs/eddsa/end.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE 
KEY----- 2 | MC4CAQAwBQYDK2VwBCIEIMU6xGVe8JTpZ3bN/wajHfw6pEHt0Rd7wPBxds9eEFy2 3 | -----END PRIVATE KEY----- 4 | -------------------------------------------------------------------------------- /tarpc/examples/compression.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2022 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression}; 8 | use futures::{prelude::*, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; 9 | use serde::{Deserialize, Serialize}; 10 | use serde_bytes::ByteBuf; 11 | use std::{io, io::Read, io::Write}; 12 | use tarpc::{ 13 | client, context, 14 | serde_transport::tcp, 15 | server::{BaseChannel, Channel}, 16 | tokio_serde::formats::Bincode, 17 | }; 18 | 19 | /// Type of compression that should be enabled on the request. The transport is free to ignore this. 
20 | #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] 21 | pub enum CompressionAlgorithm { 22 | Deflate, 23 | } 24 | 25 | #[derive(Debug, Deserialize, Serialize)] 26 | pub enum CompressedMessage { 27 | Uncompressed(T), 28 | Compressed { 29 | algorithm: CompressionAlgorithm, 30 | payload: ByteBuf, 31 | }, 32 | } 33 | 34 | #[derive(Deserialize, Serialize)] 35 | enum CompressionType { 36 | Uncompressed, 37 | Compressed, 38 | } 39 | 40 | async fn compress(message: T) -> io::Result> 41 | where 42 | T: Serialize, 43 | { 44 | let message = serialize(message)?; 45 | let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default()); 46 | encoder.write_all(&message).unwrap(); 47 | let compressed = encoder.finish()?; 48 | Ok(CompressedMessage::Compressed { 49 | algorithm: CompressionAlgorithm::Deflate, 50 | payload: ByteBuf::from(compressed), 51 | }) 52 | } 53 | 54 | async fn decompress(message: CompressedMessage) -> io::Result 55 | where 56 | for<'a> T: Deserialize<'a>, 57 | { 58 | match message { 59 | CompressedMessage::Compressed { algorithm, payload } => { 60 | if algorithm != CompressionAlgorithm::Deflate { 61 | return Err(io::Error::new( 62 | io::ErrorKind::InvalidData, 63 | format!("Compression algorithm {algorithm:?} not supported"), 64 | )); 65 | } 66 | let mut deflater = DeflateDecoder::new(payload.as_slice()); 67 | let mut payload = ByteBuf::new(); 68 | deflater.read_to_end(&mut payload)?; 69 | let message = deserialize(payload)?; 70 | Ok(message) 71 | } 72 | CompressedMessage::Uncompressed(message) => Ok(message), 73 | } 74 | } 75 | 76 | fn serialize(t: T) -> io::Result { 77 | bincode::serialize(&t) 78 | .map(ByteBuf::from) 79 | .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) 80 | } 81 | 82 | fn deserialize(message: ByteBuf) -> io::Result 83 | where 84 | for<'a> D: Deserialize<'a>, 85 | { 86 | bincode::deserialize(message.as_ref()).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) 87 | } 88 | 89 | fn add_compression( 90 | 
transport: impl Stream>> 91 | + Sink, Error = io::Error>, 92 | ) -> impl Stream> + Sink 93 | where 94 | Out: Serialize, 95 | for<'a> In: Deserialize<'a>, 96 | { 97 | transport.with(compress).and_then(decompress) 98 | } 99 | 100 | #[tarpc::service] 101 | pub trait World { 102 | async fn hello(name: String) -> String; 103 | } 104 | 105 | #[derive(Clone, Debug)] 106 | struct HelloServer; 107 | 108 | impl World for HelloServer { 109 | async fn hello(self, _: context::Context, name: String) -> String { 110 | format!("Hey, {name}!") 111 | } 112 | } 113 | 114 | async fn spawn(fut: impl Future + Send + 'static) { 115 | tokio::spawn(fut); 116 | } 117 | 118 | #[tokio::main] 119 | async fn main() -> anyhow::Result<()> { 120 | let mut incoming = tcp::listen("localhost:0", Bincode::default).await?; 121 | let addr = incoming.local_addr(); 122 | tokio::spawn(async move { 123 | let transport = incoming.next().await.unwrap().unwrap(); 124 | BaseChannel::with_defaults(add_compression(transport)) 125 | .execute(HelloServer.serve()) 126 | .for_each(spawn) 127 | .await; 128 | }); 129 | 130 | let transport = tcp::connect(addr, Bincode::default).await?; 131 | let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn(); 132 | 133 | println!( 134 | "{}", 135 | client.hello(context::current(), "friend".into()).await? 136 | ); 137 | Ok(()) 138 | } 139 | -------------------------------------------------------------------------------- /tarpc/examples/custom_transport.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2022 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 
6 | 7 | use futures::prelude::*; 8 | use tarpc::context::Context; 9 | use tarpc::serde_transport as transport; 10 | use tarpc::server::{BaseChannel, Channel}; 11 | use tarpc::tokio_serde::formats::Bincode; 12 | use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; 13 | use tokio::net::{UnixListener, UnixStream}; 14 | 15 | #[tarpc::service] 16 | pub trait PingService { 17 | async fn ping(); 18 | } 19 | 20 | #[derive(Clone)] 21 | struct Service; 22 | 23 | impl PingService for Service { 24 | async fn ping(self, _: Context) {} 25 | } 26 | 27 | #[tokio::main] 28 | async fn main() -> anyhow::Result<()> { 29 | let bind_addr = "/tmp/tarpc_on_unix_example.sock"; 30 | 31 | let _ = std::fs::remove_file(bind_addr); 32 | 33 | let listener = UnixListener::bind(bind_addr).unwrap(); 34 | let codec_builder = LengthDelimitedCodec::builder(); 35 | async fn spawn(fut: impl Future + Send + 'static) { 36 | tokio::spawn(fut); 37 | } 38 | tokio::spawn(async move { 39 | loop { 40 | let (conn, _addr) = listener.accept().await.unwrap(); 41 | let framed = codec_builder.new_framed(conn); 42 | let transport = transport::new(framed, Bincode::default()); 43 | 44 | let fut = BaseChannel::with_defaults(transport) 45 | .execute(Service.serve()) 46 | .for_each(spawn); 47 | tokio::spawn(fut); 48 | } 49 | }); 50 | 51 | let conn = UnixStream::connect(bind_addr).await?; 52 | let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); 53 | PingServiceClient::new(Default::default(), transport) 54 | .spawn() 55 | .ping(tarpc::context::current()) 56 | .await?; 57 | 58 | Ok(()) 59 | } 60 | -------------------------------------------------------------------------------- /tarpc/examples/pubsub.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // 
https://opensource.org/licenses/MIT. 6 | 7 | /// - The PubSub server sets up TCP listeners on 2 ports, the "subscriber" port and the "publisher" 8 | /// port. Because both publishers and subscribers initiate their connections to the PubSub 9 | /// server, the server requires no prior knowledge of either publishers or subscribers. 10 | /// 11 | /// - Subscribers connect to the server on the server's "subscriber" port. Once a connection is 12 | /// established, the server acts as the client of the Subscriber service, initially requesting 13 | /// the topics the subscriber is interested in, and subsequently sending topical messages to the 14 | /// subscriber. 15 | /// 16 | /// - Publishers connect to the server on the "publisher" port and, once connected, they send 17 | /// topical messages via Publisher service to the server. The server then broadcasts each 18 | /// messages to all clients subscribed to the topic of that message. 19 | /// 20 | /// Subscriber Publisher PubSub Server 21 | /// T1 | | | 22 | /// T2 |-----Connect------------------------------------------------------>| 23 | /// T3 | | | 24 | /// T2 |<-------------------------------------------------------Topics-----| 25 | /// T2 |-----(OK) Topics-------------------------------------------------->| 26 | /// T3 | | | 27 | /// T4 | |-----Connect-------------------->| 28 | /// T5 | | | 29 | /// T6 | |-----Publish-------------------->| 30 | /// T7 | | | 31 | /// T8 |<------------------------------------------------------Receive-----| 32 | /// T9 |-----(OK) Receive------------------------------------------------->| 33 | /// T10 | | | 34 | /// T11 | |<--------------(OK) Publish------| 35 | use anyhow::anyhow; 36 | use futures::{ 37 | channel::oneshot, 38 | future::{self, AbortHandle}, 39 | prelude::*, 40 | }; 41 | use opentelemetry::trace::TracerProvider as _; 42 | use publisher::Publisher as _; 43 | use std::{ 44 | collections::HashMap, 45 | error::Error, 46 | io, 47 | net::SocketAddr, 48 | sync::{Arc, Mutex, 
RwLock}, 49 | }; 50 | use subscriber::Subscriber as _; 51 | use tarpc::{ 52 | client, context, 53 | serde_transport::tcp, 54 | server::{self, Channel}, 55 | tokio_serde::formats::Json, 56 | }; 57 | use tokio::net::ToSocketAddrs; 58 | use tracing::info; 59 | use tracing_subscriber::prelude::*; 60 | 61 | pub mod subscriber { 62 | #[tarpc::service] 63 | pub trait Subscriber { 64 | async fn topics() -> Vec; 65 | async fn receive(topic: String, message: String); 66 | } 67 | } 68 | 69 | pub mod publisher { 70 | #[tarpc::service] 71 | pub trait Publisher { 72 | async fn publish(topic: String, message: String); 73 | } 74 | } 75 | 76 | #[derive(Clone, Debug)] 77 | struct Subscriber { 78 | local_addr: SocketAddr, 79 | topics: Vec, 80 | } 81 | 82 | impl subscriber::Subscriber for Subscriber { 83 | async fn topics(self, _: context::Context) -> Vec { 84 | self.topics.clone() 85 | } 86 | 87 | async fn receive(self, _: context::Context, topic: String, message: String) { 88 | info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") 89 | } 90 | } 91 | 92 | struct SubscriberHandle(AbortHandle); 93 | 94 | impl Drop for SubscriberHandle { 95 | fn drop(&mut self) { 96 | self.0.abort(); 97 | } 98 | } 99 | 100 | impl Subscriber { 101 | async fn connect( 102 | publisher_addr: impl ToSocketAddrs, 103 | topics: Vec, 104 | ) -> anyhow::Result { 105 | let publisher = tcp::connect(publisher_addr, Json::default).await?; 106 | let local_addr = publisher.local_addr()?; 107 | let mut handler = server::BaseChannel::with_defaults(publisher).requests(); 108 | let subscriber = Subscriber { local_addr, topics }; 109 | // The first request is for the topics being subscribed to. 
110 | match handler.next().await { 111 | Some(init_topics) => init_topics?.execute(subscriber.clone().serve()).await, 112 | None => { 113 | return Err(anyhow!( 114 | "[{}] Server never initialized the subscriber.", 115 | local_addr 116 | )) 117 | } 118 | }; 119 | let (handler, abort_handle) = 120 | future::abortable(handler.execute(subscriber.serve()).for_each(spawn)); 121 | tokio::spawn(async move { 122 | match handler.await { 123 | Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."), 124 | } 125 | }); 126 | Ok(SubscriberHandle(abort_handle)) 127 | } 128 | } 129 | 130 | #[derive(Debug)] 131 | struct Subscription { 132 | topics: Vec, 133 | } 134 | 135 | #[derive(Clone, Debug)] 136 | struct Publisher { 137 | clients: Arc>>, 138 | subscriptions: Arc>>>, 139 | } 140 | 141 | struct PublisherAddrs { 142 | publisher: SocketAddr, 143 | subscriptions: SocketAddr, 144 | } 145 | 146 | async fn spawn(fut: impl Future + Send + 'static) { 147 | tokio::spawn(fut); 148 | } 149 | 150 | impl Publisher { 151 | async fn start(self) -> io::Result { 152 | let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; 153 | 154 | let publisher_addrs = PublisherAddrs { 155 | publisher: connecting_publishers.local_addr(), 156 | subscriptions: self.clone().start_subscription_manager().await?, 157 | }; 158 | 159 | info!(publisher_addr = %publisher_addrs.publisher, "listening for publishers.",); 160 | tokio::spawn(async move { 161 | // Because this is just an example, we know there will only be one publisher. In more 162 | // realistic code, this would be a loop to continually accept new publisher 163 | // connections. 
164 | let publisher = connecting_publishers.next().await.unwrap().unwrap(); 165 | info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected."); 166 | 167 | server::BaseChannel::with_defaults(publisher) 168 | .execute(self.serve()) 169 | .for_each(spawn) 170 | .await 171 | }); 172 | 173 | Ok(publisher_addrs) 174 | } 175 | 176 | async fn start_subscription_manager(mut self) -> io::Result { 177 | let mut connecting_subscribers = tcp::listen("localhost:0", Json::default) 178 | .await? 179 | .filter_map(|r| future::ready(r.ok())); 180 | let new_subscriber_addr = connecting_subscribers.get_ref().local_addr(); 181 | info!(?new_subscriber_addr, "listening for subscribers."); 182 | 183 | tokio::spawn(async move { 184 | while let Some(conn) = connecting_subscribers.next().await { 185 | let subscriber_addr = conn.peer_addr().unwrap(); 186 | 187 | let tarpc::client::NewClient { 188 | client: subscriber, 189 | dispatch, 190 | } = subscriber::SubscriberClient::new(client::Config::default(), conn); 191 | let (ready_tx, ready) = oneshot::channel(); 192 | self.clone() 193 | .start_subscriber_gc(subscriber_addr, dispatch, ready); 194 | 195 | // Populate the topics 196 | self.initialize_subscription(subscriber_addr, subscriber) 197 | .await; 198 | 199 | // Signal that initialization is done. 
200 | ready_tx.send(()).unwrap(); 201 | } 202 | }); 203 | 204 | Ok(new_subscriber_addr) 205 | } 206 | 207 | async fn initialize_subscription( 208 | &mut self, 209 | subscriber_addr: SocketAddr, 210 | subscriber: subscriber::SubscriberClient, 211 | ) { 212 | // Populate the topics 213 | if let Ok(topics) = subscriber.topics(context::current()).await { 214 | self.clients.lock().unwrap().insert( 215 | subscriber_addr, 216 | Subscription { 217 | topics: topics.clone(), 218 | }, 219 | ); 220 | 221 | info!(%subscriber_addr, ?topics, "subscribed to new topics"); 222 | let mut subscriptions = self.subscriptions.write().unwrap(); 223 | for topic in topics { 224 | subscriptions 225 | .entry(topic) 226 | .or_default() 227 | .insert(subscriber_addr, subscriber.clone()); 228 | } 229 | } 230 | } 231 | 232 | fn start_subscriber_gc( 233 | self, 234 | subscriber_addr: SocketAddr, 235 | client_dispatch: impl Future> + Send + 'static, 236 | subscriber_ready: oneshot::Receiver<()>, 237 | ) { 238 | tokio::spawn(async move { 239 | if let Err(e) = client_dispatch.await { 240 | info!( 241 | %subscriber_addr, 242 | error = %e, 243 | "subscriber connection broken"); 244 | } 245 | // Don't clean up the subscriber until initialization is done. 
246 | let _ = subscriber_ready.await; 247 | if let Some(subscription) = self.clients.lock().unwrap().remove(&subscriber_addr) { 248 | info!( 249 | "[{} unsubscribing from topics: {:?}", 250 | subscriber_addr, subscription.topics 251 | ); 252 | let mut subscriptions = self.subscriptions.write().unwrap(); 253 | for topic in subscription.topics { 254 | let subscribers = subscriptions.get_mut(&topic).unwrap(); 255 | subscribers.remove(&subscriber_addr); 256 | if subscribers.is_empty() { 257 | subscriptions.remove(&topic); 258 | } 259 | } 260 | } 261 | }); 262 | } 263 | } 264 | 265 | impl publisher::Publisher for Publisher { 266 | async fn publish(self, _: context::Context, topic: String, message: String) { 267 | info!("received message to publish."); 268 | let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { 269 | None => return, 270 | Some(subscriptions) => subscriptions.clone(), 271 | }; 272 | let mut publications = Vec::new(); 273 | for client in subscribers.values_mut() { 274 | publications.push(client.receive(context::current(), topic.clone(), message.clone())); 275 | } 276 | // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until 277 | // subscribers ack. Of course, a lot would be different in a real pubsub :) 278 | for response in future::join_all(publications).await { 279 | if let Err(e) = response { 280 | info!("failed to broadcast to subscriber: {}", e); 281 | } 282 | } 283 | } 284 | } 285 | 286 | /// Initializes an OpenTelemetry tracing subscriber with a OTLP backend. 
/// Builds an OTLP batch exporter pipeline, installs it as the global tracer
/// provider, and layers it under a `tracing_subscriber` registry together
/// with an env-filter and an fmt layer.
pub fn init_tracing(service_name: &'static str) -> anyhow::Result<()> {
    let tracer_provider = opentelemetry_otlp::new_pipeline()
        .tracing()
        .with_batch_config(opentelemetry_sdk::trace::BatchConfig::default())
        .with_exporter(opentelemetry_otlp::new_exporter().tonic())
        .with_trace_config(opentelemetry_sdk::trace::Config::default().with_resource(
            opentelemetry_sdk::Resource::new([opentelemetry::KeyValue::new(
                opentelemetry_semantic_conventions::resource::SERVICE_NAME,
                service_name,
            )]),
        ))
        .install_batch(opentelemetry_sdk::runtime::Tokio)?;
    opentelemetry::global::set_tracer_provider(tracer_provider.clone());
    let tracer = tracer_provider.tracer(service_name);

    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::from_default_env())
        .with(tracing_subscriber::fmt::layer())
        .with(tracing_opentelemetry::layer().with_tracer(tracer))
        .try_init()?;

    Ok(())
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    init_tracing("Pub/Sub")?;

    // Start the publisher, which listens on two ports: one for publishers to
    // push messages to, and one for subscribers to register on.
    let addrs = Publisher {
        clients: Arc::new(Mutex::new(HashMap::new())),
        subscriptions: Arc::new(RwLock::new(HashMap::new())),
    }
    .start()
    .await?;

    // Two subscribers with partially overlapping topic sets.
    let _subscriber0 = Subscriber::connect(
        addrs.subscriptions,
        vec!["calculus".into(), "cool shorts".into()],
    )
    .await?;

    let _subscriber1 = Subscriber::connect(
        addrs.subscriptions,
        vec!["cool shorts".into(), "history".into()],
    )
    .await?;

    let publisher = publisher::PublisherClient::new(
        client::Config::default(),
        tcp::connect(addrs.publisher, Json::default).await?,
    )
    .spawn();

    publisher
        .publish(context::current(), "calculus".into(), "sqrt(2)".into())
        .await?;

    publisher
        .publish(
            context::current(),
            "cool shorts".into(),
            "hello to all".into(),
        )
        .await?;

    publisher
        .publish(context::current(), "history".into(), "napoleon".to_string())
        .await?;

    // Dropping a subscriber closes its connection; subsequent publishes to its
    // topics should no longer reach it (once the server-side GC runs).
    drop(_subscriber0);

    publisher
        .publish(
            context::current(),
            "cool shorts".into(),
            "hello to who?".into(),
        )
        .await?;

    // Flush any spans still buffered by the batch exporter before exiting.
    opentelemetry::global::shutdown_tracer_provider();
    info!("done.");

    Ok(())
}
-------------------------------------------------------------------------------- /tarpc/examples/readme.rs: --------------------------------------------------------------------------------
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

use futures::prelude::*;
use tarpc::{
    client, context,
    server::{self, Channel},
};

/// This is the service definition. It looks a lot like a trait definition.
/// It defines one RPC, hello, which takes one arg, name, and returns a String.
#[tarpc::service]
pub trait World {
    async fn hello(name: String) -> String;
}

/// This is the type that implements the generated World trait. It is the business logic
/// and is used to start the server.
22 | #[derive(Clone)] 23 | struct HelloServer; 24 | 25 | impl World for HelloServer { 26 | async fn hello(self, _: context::Context, name: String) -> String { 27 | format!("Hello, {name}!") 28 | } 29 | } 30 | 31 | async fn spawn(fut: impl Future + Send + 'static) { 32 | tokio::spawn(fut); 33 | } 34 | 35 | #[tokio::main] 36 | async fn main() -> anyhow::Result<()> { 37 | let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); 38 | 39 | let server = server::BaseChannel::with_defaults(server_transport); 40 | tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); 41 | 42 | // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` 43 | // that takes a config and any Transport as input. 44 | let client = WorldClient::new(client::Config::default(), client_transport).spawn(); 45 | 46 | // The client has an RPC method for each RPC defined in the annotated trait. It takes the same 47 | // args as defined, with the addition of a Context, which is always the first arg. The Context 48 | // specifies a deadline and trace information which can be helpful in debugging requests. 49 | let hello = client.hello(context::current(), "Stim".to_string()).await?; 50 | 51 | println!("{hello}"); 52 | 53 | Ok(()) 54 | } 55 | -------------------------------------------------------------------------------- /tarpc/examples/tls_over_tcp.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2023 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 

use futures::prelude::*;
use rustls_pemfile::certs;
use std::io::{self, BufReader, Cursor};
use std::net::{IpAddr, Ipv4Addr};

use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio_rustls::rustls::{
    self,
    server::{danger::ClientCertVerifier, WebPkiClientVerifier},
    RootCertStore,
};
use tokio_rustls::{TlsAcceptor, TlsConnector};

use tarpc::context::Context;
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode;
use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec;

/// The RPC surface exercised over the TLS transport.
#[tarpc::service]
pub trait PingService {
    async fn ping() -> String;
}

#[derive(Clone)]
struct Service;

impl PingService for Service {
    async fn ping(self, _: Context) -> String {
        "🔒".to_owned()
    }
}

// certs were generated with openssl 3 https://github.com/rustls/rustls/tree/main/test-ca
// used on client-side for server tls
const END_CHAIN: &str = include_str!("certs/eddsa/end.chain");
// used on client-side for client-auth
const CLIENT_PRIVATEKEY_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.key");
const CLIENT_CERT_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.cert");

// used on server-side for server tls
const END_CERT: &str = include_str!("certs/eddsa/end.cert");
const END_PRIVATEKEY: &str = include_str!("certs/eddsa/end.key");
// used on server-side for client-auth
const CLIENT_CHAIN_CLIENT_AUTH: &str = include_str!("certs/eddsa/client.chain");

/// Parses every certificate found in a PEM string. Panics on a malformed entry.
/// NOTE(review): the generic return type was reconstructed from a garbled dump
/// — verify against the original file.
pub fn load_certs(data: &str) -> Vec<rustls::pki_types::CertificateDer<'static>> {
    certs(&mut BufReader::new(Cursor::new(data)))
        .map(|result| result.unwrap())
        .collect()
}

/// Returns the first supported (PKCS#1, PKCS#8, or SEC1) private key found in
/// a PEM string. Panics if the PEM cannot be parsed or contains no supported
/// key; encrypted keys are not supported.
pub fn load_private_key(key: &str) -> rustls::pki_types::PrivateKeyDer<'static> {
    let mut reader = BufReader::new(Cursor::new(key));
    loop {
        match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
            Some(rustls_pemfile::Item::Pkcs1Key(key)) => return key.into(),
            Some(rustls_pemfile::Item::Pkcs8Key(key)) => return key.into(),
            Some(rustls_pemfile::Item::Sec1Key(key)) => return key.into(),
            None => break,
            // Skip non-key items (certificates etc.) and keep scanning.
            _ => continue,
        }
    }
    panic!("no keys found in {:?} (encrypted keys not supported)", key);
}

/// Runs each response future on its own tokio task.
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
    tokio::spawn(fut);
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // -------------------- start here to setup tls tcp tokio stream --------------------------
    // ref certs and loading from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/tests/test.rs
    // ref basic tls server setup from: https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/server/src/main.rs
    let cert = load_certs(END_CERT);
    let key = load_private_key(END_PRIVATEKEY);
    let server_addr = (IpAddr::V4(Ipv4Addr::LOCALHOST), 5000);

    // ------------- server side client_auth cert loading start
    let mut client_auth_roots = RootCertStore::empty();
    for root in load_certs(CLIENT_CHAIN_CLIENT_AUTH) {
        client_auth_roots.add(root).unwrap();
    }

    // NOTE(review): the trait-object type annotation was reconstructed from a
    // garbled dump — verify against the original file.
    let client_auth: Arc<dyn ClientCertVerifier> = WebPkiClientVerifier::builder(
        // allow only certificates signed by a trusted CA
        client_auth_roots.into(),
    )
    .build()
    .map_err(|err| io::Error::new(io::ErrorKind::Other, format!("{}", err)))
    .unwrap();
    // ------------- server side client_auth cert loading end

    let config = rustls::ServerConfig::builder()
        .with_client_cert_verifier(client_auth) // use .with_no_client_auth() instead if you don't want client-auth
        .with_single_cert(cert, key)
        .unwrap();
    let acceptor = TlsAcceptor::from(Arc::new(config));
    let listener = TcpListener::bind(&server_addr).await.unwrap();
    let codec_builder = LengthDelimitedCodec::builder();

    // ref ./custom_transport.rs server side
    tokio::spawn(async move {
        loop {
            // Accept a TCP connection, upgrade it to TLS, then frame it with
            // length-delimited messages for the tarpc transport.
            let (stream, _peer_addr) = listener.accept().await.unwrap();
            let tls_stream = acceptor.accept(stream).await.unwrap();
            let framed = codec_builder.new_framed(tls_stream);

            let transport = transport::new(framed, Bincode::default());

            let fut = BaseChannel::with_defaults(transport)
                .execute(Service.serve())
                .for_each(spawn);
            tokio::spawn(fut);
        }
    });

    // ---------------------- client connection ---------------------
    // tls client connection from https://github.com/tokio-rs/tls/blob/master/tokio-rustls/examples/client/src/main.rs
    let mut root_store = rustls::RootCertStore::empty();
    for root in load_certs(END_CHAIN) {
        root_store.add(root).unwrap();
    }

    let client_auth_private_key = load_private_key(CLIENT_PRIVATEKEY_CLIENT_AUTH);
    let client_auth_certs = load_certs(CLIENT_CERT_CLIENT_AUTH);

    let config = rustls::ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_client_auth_cert(client_auth_certs, client_auth_private_key)?; // use .with_no_client_auth() instead if you don't want client-auth

    // The server name must match the certificate's subject for verification.
    let domain = rustls::pki_types::ServerName::try_from("localhost")?;
    let connector = TlsConnector::from(Arc::new(config));

    let stream = TcpStream::connect(server_addr).await?;
    let stream = connector.connect(domain, stream).await?;

    let transport = transport::new(codec_builder.new_framed(stream), Bincode::default());
    let answer = PingServiceClient::new(Default::default(), transport)
        .spawn()
        .ping(tarpc::context::current())
        .await?;

    println!("ping answer: {answer}");

    Ok(())
}
-------------------------------------------------------------------------------- /tarpc/examples/tracing.rs: --------------------------------------------------------------------------------
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

#![allow(clippy::type_complexity)]

use crate::{
    add::{Add as AddService, AddStub},
    double::Double as DoubleService,
};
use futures::{future, prelude::*};
use opentelemetry::trace::TracerProvider as _;
use std::{
    io,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
};
use tarpc::{
    client::{
        self,
        stub::{load_balance, retry},
        RpcError,
    },
    context, serde_transport,
    server::{
        incoming::{spawn_incoming, Incoming},
        request_hook::{self, BeforeRequestList},
        BaseChannel,
    },
    tokio_serde::formats::Json,
    ClientMessage, RequestName, Response, ServerError, Transport,
};
use tokio::net::TcpStream;
use tracing_subscriber::prelude::*;

pub mod add {
    #[tarpc::service]
    pub trait Add {
        /// Add two ints together.
        async fn add(x: i32, y: i32) -> i32;
    }
}

pub mod double {
    #[tarpc::service]
    pub trait Double {
        /// 2 * x
        async fn double(x: i32) -> Result<i32, String>;
    }
}

#[derive(Clone)]
struct AddServer;

impl AddService for AddServer {
    async fn add(self, _: context::Context, x: i32, y: i32) -> i32 {
        x + y
    }
}

/// A Double server that delegates the arithmetic to an Add client, so a
/// single double() call produces a multi-service trace.
/// NOTE(review): the generic parameters in this file were reconstructed from
/// a garbled dump — verify against the original file.
#[derive(Clone)]
struct DoubleServer<Stub> {
    add_client: add::AddClient<Stub>,
}

impl<Stub> DoubleService for DoubleServer<Stub>
where
    Stub: AddStub + Clone + Send + Sync + 'static,
{
    async fn double(self, _: context::Context, x: i32) -> Result<i32, String> {
        self.add_client
            .add(context::current(), x, x)
            .await
            .map_err(|e| e.to_string())
    }
}

/// Initializes an OpenTelemetry tracing subscriber with a OTLP backend.
// NOTE(review): duplicated verbatim in examples/pubsub.rs.
pub fn init_tracing(service_name: &'static str) -> anyhow::Result<()> {
    let tracer_provider = opentelemetry_otlp::new_pipeline()
        .tracing()
        .with_batch_config(opentelemetry_sdk::trace::BatchConfig::default())
        .with_exporter(opentelemetry_otlp::new_exporter().tonic())
        .with_trace_config(opentelemetry_sdk::trace::Config::default().with_resource(
            opentelemetry_sdk::Resource::new([opentelemetry::KeyValue::new(
                opentelemetry_semantic_conventions::resource::SERVICE_NAME,
                service_name,
            )]),
        ))
        .install_batch(opentelemetry_sdk::runtime::Tokio)?;
    opentelemetry::global::set_tracer_provider(tracer_provider.clone());
    let tracer = tracer_provider.tracer(service_name);

    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::from_default_env())
        .with(tracing_subscriber::fmt::layer())
        .with(tracing_opentelemetry::layer().with_tracer(tracer))
        .try_init()?;

    Ok(())
}

/// Listens on an OS-assigned local port and returns a stream yielding at most
/// one accepted serde transport, along with the bound address.
async fn listen_on_random_port<Item, SinkItem>() -> anyhow::Result<(
    impl Stream<Item = serde_transport::Transport<TcpStream, Item, SinkItem, Json<Item, SinkItem>>>,
    std::net::SocketAddr,
)>
where
    Item: for<'de> serde::Deserialize<'de>,
    SinkItem: serde::Serialize,
{
    let listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
        .await?
        .filter_map(|r| future::ready(r.ok()))
        .take(1);
    let addr = listener.get_ref().get_ref().local_addr();
    Ok((listener, addr))
}

/// Builds a stub that round-robins requests over the given backend transports
/// and retries a failed request up to 3 times.
fn make_stub<Req, Resp, const N: usize>(
    backends: [impl Transport<ClientMessage<Arc<Req>>, Response<Resp>> + Send + Sync + 'static; N],
) -> retry::Retry<
    impl Fn(&Result<Resp, RpcError>, u32) -> bool + Clone,
    load_balance::RoundRobin<client::Channel<Arc<Req>, Resp>>,
>
where
    Req: RequestName + Send + Sync + 'static,
    Resp: Send + Sync + 'static,
{
    let stub = load_balance::RoundRobin::new(
        backends
            .into_iter()
            .map(|transport| tarpc::client::new(client::Config::default(), transport).spawn())
            .collect(),
    );
    retry::Retry::new(stub, |resp, attempts| {
        if let Err(e) = resp {
            tracing::warn!("Got an error: {e:?}");
            // Retry only on error, and give up after the third attempt.
            attempts < 3
        } else {
            false
        }
    })
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    init_tracing("tarpc_tracing_example")?;

    let (add_listener1, addr1) = listen_on_random_port().await?;
    let (add_listener2, addr2) = listen_on_random_port().await?;
    // A request hook that fails every other request, to exercise the retry
    // stub: fetch_xor(true) alternates the flag and returns its prior value.
    let something_bad_happened = Arc::new(AtomicBool::new(false));
    let server = request_hook::before()
        .then_fn(move |_: &mut _, _: &_| {
            let something_bad_happened = something_bad_happened.clone();
            async move {
                if something_bad_happened.fetch_xor(true, Ordering::Relaxed) {
                    Err(ServerError::new(
                        io::ErrorKind::NotFound,
                        "Gamma Ray!".into(),
                    ))
                } else {
                    Ok(())
                }
            }
        })
        .serving(AddServer.serve());
    let add_server = add_listener1
        .chain(add_listener2)
        .map(BaseChannel::with_defaults);
    tokio::spawn(spawn_incoming(add_server.execute(server)));

    // The Add client fans out over both Add server endpoints.
    let add_client = add::AddClient::from(make_stub([
        tarpc::serde_transport::tcp::connect(addr1, Json::default).await?,
        tarpc::serde_transport::tcp::connect(addr2, Json::default).await?,
    ]));

    let double_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default)
        .await?
        .filter_map(|r| future::ready(r.ok()));
    let addr = double_listener.get_ref().local_addr();
    let double_server = double_listener.map(BaseChannel::with_defaults).take(1);
    let server = DoubleServer { add_client }.serve();
    tokio::spawn(spawn_incoming(double_server.execute(server)));

    let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?;
    let double_client =
        double::DoubleClient::new(client::Config::default(), to_double_server).spawn();

    let ctx = context::current();
    for _ in 1..=5 {
        tracing::info!("{:?}", double_client.double(ctx, 1).await?);
    }

    // Flush buffered spans before exiting.
    opentelemetry::global::shutdown_tracer_provider();

    Ok(())
}
-------------------------------------------------------------------------------- /tarpc/rustfmt.toml: --------------------------------------------------------------------------------
edition = "2018"
-------------------------------------------------------------------------------- /tarpc/src/cancellations.rs: --------------------------------------------------------------------------------
use futures::{prelude::*, task::*};
use std::pin::Pin;
use tokio::sync::mpsc;

/// Sends request cancellation signals.
// NOTE(review): the channel item type (u64 request IDs) was reconstructed
// from a garbled dump — verify against the original file.
#[derive(Debug, Clone)]
pub struct RequestCancellation(mpsc::UnboundedSender<u64>);

/// A stream of IDs of requests that have been canceled.
#[derive(Debug)]
pub struct CanceledRequests(mpsc::UnboundedReceiver<u64>);

/// Returns a channel to send request cancellation messages.
pub fn cancellations() -> (RequestCancellation, CanceledRequests) {
    // Unbounded because messages are sent in the drop fn. This is fine, because it's still
    // bounded by the number of in-flight requests.
    let (tx, rx) = mpsc::unbounded_channel();
    (RequestCancellation(tx), CanceledRequests(rx))
}

impl RequestCancellation {
    /// Cancels the request with ID `request_id`.
    ///
    /// No validation is done of `request_id`. There is no way to know if the request id provided
    /// corresponds to a request actually tracked by the backing channel. `RequestCancellation` is
    /// a one-way communication channel.
    ///
    /// Once request data is cleaned up, a response will never be received by the client. This is
    /// useful primarily when request processing ends prematurely for requests with long deadlines
    /// which would otherwise continue to be tracked by the backing channel—a kind of leak.
    pub fn cancel(&self, request_id: u64) {
        // A send error only means the receiver is gone; cancellation is
        // best-effort, so the error is deliberately ignored.
        let _ = self.0.send(request_id);
    }
}

impl CanceledRequests {
    /// Polls for a cancelled request.
    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
        self.0.poll_recv(cx)
    }
}

impl Stream for CanceledRequests {
    type Item = u64;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
        self.poll_recv(cx)
    }
}
-------------------------------------------------------------------------------- /tarpc/src/client/in_flight_requests.rs: --------------------------------------------------------------------------------
use crate::{
    context,
    util::{Compact, TimeUntil},
};
use fnv::FnvHashMap;
use std::{
    collections::hash_map,
    task::{Context, Poll},
};
use tokio::sync::oneshot;
use tokio_util::time::delay_queue::{self, DelayQueue};
use tracing::Span;

/// Requests already written to the wire that haven't yet received responses.
// NOTE(review): the `Res` generic parameters below were reconstructed from a
// garbled dump — verify against the original file.
#[derive(Debug)]
pub struct InFlightRequests<Res> {
    // Keyed by request ID; fnv is a fast hasher for small integer keys.
    request_data: FnvHashMap<u64, RequestData<Res>>,
    // Timer wheel of request deadlines, keyed by request ID.
    deadlines: DelayQueue<u64>,
}

impl<Res> Default for InFlightRequests<Res> {
    fn default() -> Self {
        Self {
            request_data: Default::default(),
            deadlines: Default::default(),
        }
    }
}

/// Per-request bookkeeping held while a response is outstanding.
#[derive(Debug)]
struct RequestData<Res> {
    ctx: context::Context,
    span: Span,
    // Completing this sender delivers the response to the caller.
    response_completion: oneshot::Sender<Res>,
    /// The key to remove the timer for the request's deadline.
    deadline_key: delay_queue::Key,
}

/// An error returned when an attempt is made to insert a request with an ID that is already in
/// use.
#[derive(Debug)]
pub struct AlreadyExistsError;

impl<Res> InFlightRequests<Res> {
    /// Returns the number of in-flight requests.
    pub fn len(&self) -> usize {
        self.request_data.len()
    }

    /// Returns true iff there are no requests in flight.
    pub fn is_empty(&self) -> bool {
        self.request_data.is_empty()
    }

    /// Starts a request, unless a request with the same ID is already in flight.
    pub fn insert_request(
        &mut self,
        request_id: u64,
        ctx: context::Context,
        span: Span,
        response_completion: oneshot::Sender<Res>,
    ) -> Result<(), AlreadyExistsError> {
        match self.request_data.entry(request_id) {
            hash_map::Entry::Vacant(vacant) => {
                // Schedule the deadline timer before storing the entry so the
                // entry always carries a valid deadline_key.
                let timeout = ctx.deadline.time_until();
                let deadline_key = self.deadlines.insert(request_id, timeout);
                vacant.insert(RequestData {
                    ctx,
                    span,
                    response_completion,
                    deadline_key,
                });
                Ok(())
            }
            hash_map::Entry::Occupied(_) => Err(AlreadyExistsError),
        }
    }

    /// Removes a request without aborting, sending `result` to the waiting
    /// caller. Returns the request's span iff the request was found.
    pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option<Span> {
        if let Some(request_data) = self.request_data.remove(&request_id) {
            // Amortized reclamation of map capacity after removals.
            self.request_data.compact(0.1);
            self.deadlines.remove(&request_data.deadline_key);
            // The receiver may have been dropped; that's fine — ignore.
            let _ = request_data.response_completion.send(result);
            return Some(request_data.span);
        }

        tracing::debug!("No in-flight request found for request_id = {request_id}.");

        // If the response completion was absent, then the request was already canceled.
        None
    }

    /// Completes all requests using the provided function.
    /// Returns Spans for all completed requests.
    pub fn complete_all_requests<'a>(
        &'a mut self,
        mut result: impl FnMut() -> Res + 'a,
    ) -> impl Iterator<Item = Span> + 'a {
        self.deadlines.clear();
        self.request_data.drain().map(move |(_, request_data)| {
            let _ = request_data.response_completion.send(result());
            request_data.span
        })
    }

    /// Cancels a request without completing (typically used when a request handle was dropped
    /// before the request completed).
    pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
        if let Some(request_data) = self.request_data.remove(&request_id) {
            // Amortized reclamation of map capacity after removals.
            self.request_data.compact(0.1);
            self.deadlines.remove(&request_data.deadline_key);
            Some((request_data.ctx, request_data.span))
        } else {
            None
        }
    }

    /// Yields a request that has expired, completing it with a TimedOut error.
    /// The caller should send cancellation messages for any yielded request ID.
    pub fn poll_expired(
        &mut self,
        cx: &mut Context,
        expired_error: impl Fn() -> Res,
    ) -> Poll<Option<u64>> {
        self.deadlines.poll_expired(cx).map(|expired| {
            let request_id = expired?.into_inner();
            if let Some(request_data) = self.request_data.remove(&request_id) {
                // Log the timeout inside the request's own span.
                let _entered = request_data.span.enter();
                tracing::error!("DeadlineExceeded");
                self.request_data.compact(0.1);
                let _ = request_data.response_completion.send(expired_error());
            }
            Some(request_id)
        })
    }
}
-------------------------------------------------------------------------------- /tarpc/src/client/stub.rs: --------------------------------------------------------------------------------
//! Provides a Stub trait, implemented by types that can call remote services.

use crate::{
    client::{Channel, RpcError},
    context,
    server::Serve,
    RequestName,
};

pub mod load_balance;
pub mod retry;

#[cfg(test)]
mod mock;

/// A connection to a remote service.
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
#[allow(async_fn_in_trait)]
pub trait Stub {
    /// The service request type.
    type Req: RequestName;

    /// The service response type.
    type Resp;

    /// Calls a remote service.
    async fn call(&self, ctx: context::Context, request: Self::Req)
        -> Result<Self::Resp, RpcError>;
}

// A client Channel is itself a Stub: calls go over the wire.
// NOTE(review): generic parameters in this section were reconstructed from a
// garbled dump — verify against the original file.
impl<Req, Resp> Stub for Channel<Req, Resp>
where
    Req: RequestName,
{
    type Req = Req;
    type Resp = Resp;

    async fn call(&self, ctx: context::Context, request: Req) -> Result<Resp, RpcError> {
        Self::call(self, ctx, request).await
    }
}

// Any cloneable Serve implementation is also a Stub: calls are served
// in-process, with server errors surfaced as RpcError::Server.
impl<S> Stub for S
where
    S: Serve + Clone,
{
    type Req = S::Req;
    type Resp = S::Resp;
    async fn call(&self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, RpcError> {
        self.clone().serve(ctx, req).await.map_err(RpcError::Server)
    }
}
-------------------------------------------------------------------------------- /tarpc/src/client/stub/load_balance.rs: --------------------------------------------------------------------------------
//! Provides load-balancing [Stubs](crate::client::stub::Stub).

pub use consistent_hash::ConsistentHash;
pub use round_robin::RoundRobin;

/// Provides a stub that load-balances with a simple round-robin strategy.
mod round_robin {
    use crate::{
        client::{stub, RpcError},
        context,
    };
    use cycle::AtomicCycle;

    impl<Stub> stub::Stub for RoundRobin<Stub>
    where
        Stub: stub::Stub,
    {
        type Req = Stub::Req;
        type Resp = Stub::Resp;

        async fn call(
            &self,
            ctx: context::Context,
            request: Self::Req,
        ) -> Result<Self::Resp, RpcError> {
            // Each call advances the cycle to the next backing stub.
            let next = self.stubs.next();
            next.call(ctx, request).await
        }
    }

    /// A Stub that load-balances across backing stubs by round robin.
    #[derive(Clone, Debug)]
    pub struct RoundRobin<Stub> {
        stubs: AtomicCycle<Stub>,
    }

    impl<Stub> RoundRobin<Stub>
    where
        Stub: stub::Stub,
    {
        /// Returns a new RoundRobin stub.
        pub fn new(stubs: Vec<Stub>) -> Self {
            Self {
                stubs: AtomicCycle::new(stubs),
            }
        }
    }

    mod cycle {
        use std::sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        };

        /// Cycles endlessly and atomically over a collection of elements of type T.
        // NOTE(review): generic parameters reconstructed from a garbled dump —
        // verify against the original file.
        #[derive(Clone, Debug)]
        pub struct AtomicCycle<T>(Arc<State<T>>);

        #[derive(Debug)]
        struct State<T> {
            elements: Vec<T>,
            // Monotonically increasing cursor; reduced modulo len on access.
            next: AtomicUsize,
        }

        impl<T> AtomicCycle<T> {
            pub fn new(elements: Vec<T>) -> Self {
                Self(Arc::new(State {
                    elements,
                    next: Default::default(),
                }))
            }

            pub fn next(&self) -> &T {
                self.0.next()
            }
        }

        impl<T> State<T> {
            pub fn next(&self) -> &T {
                // Relaxed is sufficient: only the counter value matters, not
                // ordering with other memory operations.
                let next = self.next.fetch_add(1, Ordering::Relaxed);
                &self.elements[next % self.elements.len()]
            }
        }

        #[test]
        fn test_cycle() {
            let cycle = AtomicCycle::new(vec![1, 2, 3]);
            assert_eq!(cycle.next(), &1);
            assert_eq!(cycle.next(), &2);
            assert_eq!(cycle.next(), &3);
            assert_eq!(cycle.next(), &1);
        }
    }
}

/// Provides a stub that load-balances with a consistent hashing strategy.
///
/// Each request is hashed, then mapped to a stub based on the hash. Equivalent requests will use
/// the same stub.
mod consistent_hash {
    use crate::{
        client::{stub, RpcError},
        context,
    };
    use std::{
        collections::hash_map::RandomState,
        hash::{BuildHasher, Hash, Hasher},
        num::TryFromIntError,
    };

    // NOTE(review): generic parameters in this module were reconstructed from
    // a garbled dump — verify against the original file.
    impl<Stub, S> stub::Stub for ConsistentHash<Stub, S>
    where
        Stub: stub::Stub,
        Stub::Req: Hash,
        S: BuildHasher,
    {
        type Req = Stub::Req;
        type Resp = Stub::Resp;

        async fn call(
            &self,
            ctx: context::Context,
            request: Self::Req,
        ) -> Result<Self::Resp, RpcError> {
            // hash % stubs_len < stubs_len <= usize::MAX, so the conversion
            // cannot fail; the expect documents the invariant.
            let index = usize::try_from(self.hash_request(&request) % self.stubs_len).expect(
                "invariant broken: stubs_len is not larger than a usize, \
                so the hash modulo stubs_len should always fit in a usize",
            );
            let next = &self.stubs[index];
            next.call(ctx, request).await
        }
    }

    /// A Stub that load-balances across backing stubs by consistent hashing.
    #[derive(Clone, Debug)]
    pub struct ConsistentHash<Stub, S = RandomState> {
        stubs: Vec<Stub>,
        // Cached stubs.len() as u64 so hashes can be reduced without casting
        // on every call.
        stubs_len: u64,
        hasher: S,
    }

    impl<Stub> ConsistentHash<Stub>
    where
        Stub: stub::Stub,
        Stub::Req: Hash,
    {
        /// Returns a new ConsistentHash stub using the default hasher.
        /// Returns an err if the length of `stubs` overflows a u64.
        pub fn new(stubs: Vec<Stub>) -> Result<Self, TryFromIntError> {
            Ok(Self {
                stubs_len: stubs.len().try_into()?,
                stubs,
                hasher: RandomState::new(),
            })
        }
    }

    impl<Stub, S> ConsistentHash<Stub, S>
    where
        Stub: stub::Stub,
        Stub::Req: Hash,
        S: BuildHasher,
    {
        /// Returns a new ConsistentHash stub with a caller-provided hasher.
        /// Returns an err if the length of `stubs` overflows a u64.
        pub fn with_hasher(stubs: Vec<Stub>, hasher: S) -> Result<Self, TryFromIntError> {
            Ok(Self {
                stubs_len: stubs.len().try_into()?,
                stubs,
                hasher,
            })
        }

        fn hash_request(&self, req: &Stub::Req) -> u64 {
            let mut hasher = self.hasher.build_hasher();
            req.hash(&mut hasher);
            hasher.finish()
        }
    }

    #[cfg(test)]
    mod tests {
        use super::ConsistentHash;
        use crate::{
            client::stub::{mock::Mock, Stub},
            context,
        };
        use std::{
            collections::HashMap,
            hash::{BuildHasher, Hash, Hasher},
            rc::Rc,
        };

        #[tokio::test]
        async fn test() -> anyhow::Result<()> {
            let stub = ConsistentHash::<_, FakeHasherBuilder>::with_hasher(
                vec![
                    // For easier reading of the assertions made in this test, each Mock's response
                    // value is equal to a hash value that should map to its index: 3 % 3 = 0, 1 %
                    // 3 = 1, etc.
                    Mock::new([('a', 3), ('b', 3), ('c', 3)]),
                    Mock::new([('a', 1), ('b', 1), ('c', 1)]),
                    Mock::new([('a', 2), ('b', 2), ('c', 2)]),
                ],
                FakeHasherBuilder::new([('a', 1), ('b', 2), ('c', 3)]),
            )?;

            // Two passes verify that the mapping is stable across calls.
            for _ in 0..2 {
                let resp = stub.call(context::current(), 'a').await?;
                assert_eq!(resp, 1);

                let resp = stub.call(context::current(), 'b').await?;
                assert_eq!(resp, 2);

                let resp = stub.call(context::current(), 'c').await?;
                assert_eq!(resp, 3);
            }

            Ok(())
        }

        /// A Hasher that just records the bytes written into it, used to key
        /// the fake-hash table.
        struct HashRecorder(Vec<u8>);
        impl Hasher for HashRecorder {
            fn write(&mut self, bytes: &[u8]) {
                self.0 = Vec::from(bytes);
            }
            fn finish(&self) -> u64 {
                0
            }
        }

        /// Builds FakeHashers that look up predetermined hash values by the
        /// bytes hashed.
        struct FakeHasherBuilder {
            recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
        }

        struct FakeHasher {
            recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
            output: u64,
        }

        impl BuildHasher for FakeHasherBuilder {
            type Hasher = FakeHasher;

            fn build_hasher(&self) -> Self::Hasher {
                FakeHasher {
                    recorded_hashes: self.recorded_hashes.clone(),
                    output: 0,
                }
            }
        }

        impl FakeHasherBuilder {
            fn new<T: Hash, const N: usize>(fake_hashes: [(T, u64); N]) -> Self {
                let mut recorded_hashes = HashMap::new();
                for (to_hash, fake_hash) in fake_hashes {
                    // Record each value's hashed byte representation so the
                    // FakeHasher can recognize it later.
                    let mut recorder = HashRecorder(vec![]);
                    to_hash.hash(&mut recorder);
                    recorded_hashes.insert(recorder.0, fake_hash);
                }
                Self {
                    recorded_hashes: Rc::new(recorded_hashes),
                }
            }
        }

        impl Hasher for FakeHasher {
            fn write(&mut self, bytes: &[u8]) {
                if let Some(hash) = self.recorded_hashes.get(bytes) {
                    self.output = *hash;
                }
            }
            fn finish(&self) -> u64 {
                self.output
            }
        }
    }
}
-------------------------------------------------------------------------------- /tarpc/src/client/stub/mock.rs: --------------------------------------------------------------------------------
use crate::{
    client::{stub::Stub, RpcError},
    context, RequestName, ServerError,
};
use std::{collections::HashMap, hash::Hash, io};

/// A mock stub that returns user-specified responses.
pub struct Mock<Req, Resp> {
    responses: HashMap<Req, Resp>,
}

impl<Req, Resp> Mock<Req, Resp>
where
    Req: Eq + Hash,
{
    /// Returns a new mock, mocking the specified (request, response) pairs.
17 | pub fn new(responses: [(Req, Resp); N]) -> Self { 18 | Self { 19 | responses: HashMap::from(responses), 20 | } 21 | } 22 | } 23 | 24 | impl Stub for Mock 25 | where 26 | Req: Eq + Hash + RequestName, 27 | Resp: Clone, 28 | { 29 | type Req = Req; 30 | type Resp = Resp; 31 | 32 | async fn call(&self, _: context::Context, request: Self::Req) -> Result { 33 | self.responses 34 | .get(&request) 35 | .cloned() 36 | .map(Ok) 37 | .unwrap_or_else(|| { 38 | Err(RpcError::Server(ServerError { 39 | kind: io::ErrorKind::NotFound, 40 | detail: "mock (request, response) entry not found".into(), 41 | })) 42 | }) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /tarpc/src/client/stub/retry.rs: -------------------------------------------------------------------------------- 1 | //! Provides a stub that retries requests based on response contents.. 2 | 3 | use crate::{ 4 | client::{stub, RpcError}, 5 | context, RequestName, 6 | }; 7 | use std::sync::Arc; 8 | 9 | impl stub::Stub for Retry 10 | where 11 | Req: RequestName, 12 | Stub: stub::Stub>, 13 | F: Fn(&Result, u32) -> bool, 14 | { 15 | type Req = Req; 16 | type Resp = Stub::Resp; 17 | 18 | async fn call( 19 | &self, 20 | ctx: context::Context, 21 | request: Self::Req, 22 | ) -> Result { 23 | let request = Arc::new(request); 24 | for i in 1.. { 25 | let result = self.stub.call(ctx, Arc::clone(&request)).await; 26 | if (self.should_retry)(&result, i) { 27 | tracing::trace!("Retrying on attempt {i}"); 28 | continue; 29 | } 30 | return result; 31 | } 32 | unreachable!("Wow, that was a lot of attempts!"); 33 | } 34 | } 35 | 36 | /// A Stub that retries requests based on response contents. 37 | /// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled. 
#[derive(Clone, Debug)]
pub struct Retry<F, Stub> {
    should_retry: F,
    stub: Stub,
}

impl<Stub, Req, F> Retry<F, Stub>
where
    Req: RequestName,
    Stub: stub::Stub<Req = Arc<Req>>,
    F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
{
    /// Creates a new Retry stub that delegates calls to the underlying `stub`.
    ///
    /// `should_retry` receives each attempt's result and the (1-based) attempt number, and
    /// returns true if the request should be retried.
    pub fn new(stub: Stub, should_retry: F) -> Self {
        Self { stub, should_retry }
    }
}

// ---- tarpc/src/context.rs ----

// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

//! Provides a request context that carries a deadline and trace context. This context is sent from
//! client to server and is used by the server to enforce response deadlines.

use crate::trace::{self, TraceId};
use opentelemetry::trace::TraceContextExt;
use static_assertions::assert_impl_all;
use std::{
    convert::TryFrom,
    time::{Duration, Instant},
};
use tracing_opentelemetry::OpenTelemetrySpanExt;

/// A request context that carries request-scoped information like deadlines and trace information.
/// It is sent from client to server and is used by the server to enforce response deadlines.
///
/// The context should not be stored directly in a server implementation, because the context will
/// be different for each request in scope.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Context {
    /// When the client expects the request to be complete by. The server should cancel the request
    /// if it is not complete by this time.
30 | #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))] 31 | // Serialized as a Duration to prevent clock skew issues. 32 | #[cfg_attr(feature = "serde1", serde(with = "absolute_to_relative_time"))] 33 | pub deadline: Instant, 34 | /// Uniquely identifies requests originating from the same source. 35 | /// When a service handles a request by making requests itself, those requests should 36 | /// include the same `trace_id` as that included on the original request. This way, 37 | /// users can trace related actions across a distributed system. 38 | pub trace_context: trace::Context, 39 | } 40 | 41 | #[cfg(feature = "serde1")] 42 | mod absolute_to_relative_time { 43 | pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; 44 | pub use std::time::{Duration, Instant}; 45 | 46 | pub fn serialize(deadline: &Instant, serializer: S) -> Result 47 | where 48 | S: Serializer, 49 | { 50 | let deadline = deadline.duration_since(Instant::now()); 51 | deadline.serialize(serializer) 52 | } 53 | 54 | pub fn deserialize<'de, D>(deserializer: D) -> Result 55 | where 56 | D: Deserializer<'de>, 57 | { 58 | let deadline = Duration::deserialize(deserializer)?; 59 | Ok(Instant::now() + deadline) 60 | } 61 | 62 | #[cfg(test)] 63 | #[derive(serde::Serialize, serde::Deserialize)] 64 | struct AbsoluteToRelative(#[serde(with = "self")] Instant); 65 | 66 | #[test] 67 | fn test_serialize() { 68 | let now = Instant::now(); 69 | let deadline = now + Duration::from_secs(10); 70 | let serialized_deadline = bincode::serialize(&AbsoluteToRelative(deadline)).unwrap(); 71 | let deserialized_deadline: Duration = bincode::deserialize(&serialized_deadline).unwrap(); 72 | // TODO: how to avoid flakiness? 
73 | assert!(deserialized_deadline > Duration::from_secs(9)); 74 | } 75 | 76 | #[test] 77 | fn test_deserialize() { 78 | let deadline = Duration::from_secs(10); 79 | let serialized_deadline = bincode::serialize(&deadline).unwrap(); 80 | let AbsoluteToRelative(deserialized_deadline) = 81 | bincode::deserialize(&serialized_deadline).unwrap(); 82 | // TODO: how to avoid flakiness? 83 | assert!(deserialized_deadline > Instant::now() + Duration::from_secs(9)); 84 | } 85 | } 86 | 87 | assert_impl_all!(Context: Send, Sync); 88 | 89 | fn ten_seconds_from_now() -> Instant { 90 | Instant::now() + Duration::from_secs(10) 91 | } 92 | 93 | /// Returns the context for the current request, or a default Context if no request is active. 94 | pub fn current() -> Context { 95 | Context::current() 96 | } 97 | 98 | #[derive(Clone)] 99 | struct Deadline(Instant); 100 | 101 | impl Default for Deadline { 102 | fn default() -> Self { 103 | Self(ten_seconds_from_now()) 104 | } 105 | } 106 | 107 | impl Context { 108 | /// Returns the context for the current request, or a default Context if no request is active. 109 | pub fn current() -> Self { 110 | let span = tracing::Span::current(); 111 | Self { 112 | trace_context: trace::Context::try_from(&span) 113 | .unwrap_or_else(|_| trace::Context::default()), 114 | deadline: span 115 | .context() 116 | .get::() 117 | .cloned() 118 | .unwrap_or_default() 119 | .0, 120 | } 121 | } 122 | 123 | /// Returns the ID of the request-scoped trace. 124 | pub fn trace_id(&self) -> &TraceId { 125 | &self.trace_context.trace_id 126 | } 127 | } 128 | 129 | /// An extension trait for [`tracing::Span`] for propagating tarpc Contexts. 130 | pub(crate) trait SpanExt { 131 | /// Sets the given context on this span. Newly-created spans will be children of the given 132 | /// context's trace context. 
133 | fn set_context(&self, context: &Context); 134 | } 135 | 136 | impl SpanExt for tracing::Span { 137 | fn set_context(&self, context: &Context) { 138 | self.set_parent( 139 | opentelemetry::Context::new() 140 | .with_remote_span_context(opentelemetry::trace::SpanContext::new( 141 | opentelemetry::trace::TraceId::from(context.trace_context.trace_id), 142 | opentelemetry::trace::SpanId::from(context.trace_context.span_id), 143 | opentelemetry::trace::TraceFlags::from(context.trace_context.sampling_decision), 144 | true, 145 | opentelemetry::trace::TraceState::default(), 146 | )) 147 | .with_value(Deadline(context.deadline)), 148 | ); 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /tarpc/src/server/in_flight_requests.rs: -------------------------------------------------------------------------------- 1 | use crate::util::{Compact, TimeUntil}; 2 | use fnv::FnvHashMap; 3 | use futures::future::{AbortHandle, AbortRegistration}; 4 | use std::{ 5 | collections::hash_map, 6 | task::{Context, Poll}, 7 | time::Instant, 8 | }; 9 | use tokio_util::time::delay_queue::{self, DelayQueue}; 10 | use tracing::Span; 11 | 12 | /// A data structure that tracks in-flight requests. It aborts requests, 13 | /// either on demand or when a request deadline expires. 14 | #[derive(Debug, Default)] 15 | pub struct InFlightRequests { 16 | request_data: FnvHashMap, 17 | deadlines: DelayQueue, 18 | } 19 | 20 | /// Data needed to clean up a single in-flight request. 21 | #[derive(Debug)] 22 | struct RequestData { 23 | /// Aborts the response handler for the associated request. 24 | abort_handle: AbortHandle, 25 | /// The key to remove the timer for the request's deadline. 26 | deadline_key: delay_queue::Key, 27 | /// The client span. 28 | span: Span, 29 | } 30 | 31 | /// An error returned when a request attempted to start with the same ID as a request already 32 | /// in flight. 
33 | #[derive(Debug)] 34 | pub struct AlreadyExistsError; 35 | 36 | impl InFlightRequests { 37 | /// Returns the number of in-flight requests. 38 | pub fn len(&self) -> usize { 39 | self.request_data.len() 40 | } 41 | 42 | /// Starts a request, unless a request with the same ID is already in flight. 43 | pub fn start_request( 44 | &mut self, 45 | request_id: u64, 46 | deadline: Instant, 47 | span: Span, 48 | ) -> Result { 49 | match self.request_data.entry(request_id) { 50 | hash_map::Entry::Vacant(vacant) => { 51 | let timeout = deadline.time_until(); 52 | let (abort_handle, abort_registration) = AbortHandle::new_pair(); 53 | let deadline_key = self.deadlines.insert(request_id, timeout); 54 | vacant.insert(RequestData { 55 | abort_handle, 56 | deadline_key, 57 | span, 58 | }); 59 | Ok(abort_registration) 60 | } 61 | hash_map::Entry::Occupied(_) => Err(AlreadyExistsError), 62 | } 63 | } 64 | 65 | /// Cancels an in-flight request. Returns true iff the request was found. 66 | pub fn cancel_request(&mut self, request_id: u64) -> bool { 67 | if let Some(RequestData { 68 | span, 69 | abort_handle, 70 | deadline_key, 71 | }) = self.request_data.remove(&request_id) 72 | { 73 | let _entered = span.enter(); 74 | self.request_data.compact(0.1); 75 | abort_handle.abort(); 76 | self.deadlines.remove(&deadline_key); 77 | tracing::info!("ReceiveCancel"); 78 | true 79 | } else { 80 | false 81 | } 82 | } 83 | 84 | /// Removes a request without aborting. Returns true iff the request was found. 85 | /// This method should be used when a response is being sent. 86 | pub fn remove_request(&mut self, request_id: u64) -> Option { 87 | if let Some(request_data) = self.request_data.remove(&request_id) { 88 | self.request_data.compact(0.1); 89 | self.deadlines.remove(&request_data.deadline_key); 90 | Some(request_data.span) 91 | } else { 92 | None 93 | } 94 | } 95 | 96 | /// Yields a request that has expired, aborting any ongoing processing of that request. 
97 | pub fn poll_expired(&mut self, cx: &mut Context) -> Poll> { 98 | if self.deadlines.is_empty() { 99 | // TODO(https://github.com/tokio-rs/tokio/issues/4161) 100 | // This is a workaround for DelayQueue not always treating this case correctly. 101 | return Poll::Ready(None); 102 | } 103 | self.deadlines.poll_expired(cx).map(|expired| { 104 | let expired = expired?; 105 | if let Some(RequestData { 106 | abort_handle, span, .. 107 | }) = self.request_data.remove(expired.get_ref()) 108 | { 109 | let _entered = span.enter(); 110 | self.request_data.compact(0.1); 111 | abort_handle.abort(); 112 | tracing::error!("DeadlineExceeded"); 113 | } 114 | Some(expired.into_inner()) 115 | }) 116 | } 117 | } 118 | 119 | /// When InFlightRequests is dropped, any outstanding requests are aborted. 120 | impl Drop for InFlightRequests { 121 | fn drop(&mut self) { 122 | self.request_data 123 | .values() 124 | .for_each(|request_data| request_data.abort_handle.abort()) 125 | } 126 | } 127 | 128 | #[cfg(test)] 129 | mod tests { 130 | use super::*; 131 | 132 | use assert_matches::assert_matches; 133 | use futures::{ 134 | future::{pending, Abortable}, 135 | FutureExt, 136 | }; 137 | use futures_test::task::noop_context; 138 | 139 | #[tokio::test] 140 | async fn start_request_increases_len() { 141 | let mut in_flight_requests = InFlightRequests::default(); 142 | assert_eq!(in_flight_requests.len(), 0); 143 | in_flight_requests 144 | .start_request(0, Instant::now(), Span::current()) 145 | .unwrap(); 146 | assert_eq!(in_flight_requests.len(), 1); 147 | } 148 | 149 | #[tokio::test] 150 | async fn polling_expired_aborts() { 151 | let mut in_flight_requests = InFlightRequests::default(); 152 | let abort_registration = in_flight_requests 153 | .start_request(0, Instant::now(), Span::current()) 154 | .unwrap(); 155 | let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); 156 | 157 | tokio::time::pause(); 158 | 
tokio::time::advance(std::time::Duration::from_secs(1000)).await; 159 | 160 | assert_matches!( 161 | in_flight_requests.poll_expired(&mut noop_context()), 162 | Poll::Ready(Some(_)) 163 | ); 164 | assert_matches!( 165 | abortable_future.poll_unpin(&mut noop_context()), 166 | Poll::Ready(Err(_)) 167 | ); 168 | assert_eq!(in_flight_requests.len(), 0); 169 | } 170 | 171 | #[tokio::test] 172 | async fn cancel_request_aborts() { 173 | let mut in_flight_requests = InFlightRequests::default(); 174 | let abort_registration = in_flight_requests 175 | .start_request(0, Instant::now(), Span::current()) 176 | .unwrap(); 177 | let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); 178 | 179 | assert!(in_flight_requests.cancel_request(0)); 180 | assert_matches!( 181 | abortable_future.poll_unpin(&mut noop_context()), 182 | Poll::Ready(Err(_)) 183 | ); 184 | assert_eq!(in_flight_requests.len(), 0); 185 | } 186 | 187 | #[tokio::test] 188 | async fn remove_request_doesnt_abort() { 189 | let mut in_flight_requests = InFlightRequests::default(); 190 | assert!(in_flight_requests.deadlines.is_empty()); 191 | 192 | let abort_registration = in_flight_requests 193 | .start_request( 194 | 0, 195 | Instant::now() + std::time::Duration::from_secs(10), 196 | Span::current(), 197 | ) 198 | .unwrap(); 199 | let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); 200 | 201 | // Precondition: Pending expiration 202 | assert_matches!( 203 | in_flight_requests.poll_expired(&mut noop_context()), 204 | Poll::Pending 205 | ); 206 | assert!(!in_flight_requests.deadlines.is_empty()); 207 | 208 | assert_matches!(in_flight_requests.remove_request(0), Some(_)); 209 | // Postcondition: No pending expirations 210 | assert!(in_flight_requests.deadlines.is_empty()); 211 | assert_matches!( 212 | in_flight_requests.poll_expired(&mut noop_context()), 213 | Poll::Ready(None) 214 | ); 215 | assert_matches!( 216 | abortable_future.poll_unpin(&mut 
noop_context()), 217 | Poll::Pending 218 | ); 219 | assert_eq!(in_flight_requests.len(), 0); 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /tarpc/src/server/incoming.rs: -------------------------------------------------------------------------------- 1 | use super::{ 2 | limits::{channels_per_key::MaxChannelsPerKey, requests_per_channel::MaxRequestsPerChannel}, 3 | Channel, RequestName, Serve, 4 | }; 5 | use futures::prelude::*; 6 | use std::{fmt, hash::Hash}; 7 | 8 | /// An extension trait for [streams](futures::prelude::Stream) of [`Channels`](Channel). 9 | pub trait Incoming 10 | where 11 | Self: Sized + Stream, 12 | C: Channel, 13 | { 14 | /// Enforces channel per-key limits. 15 | fn max_channels_per_key(self, n: u32, keymaker: KF) -> MaxChannelsPerKey 16 | where 17 | K: fmt::Display + Eq + Hash + Clone + Unpin, 18 | KF: Fn(&C) -> K, 19 | { 20 | MaxChannelsPerKey::new(self, n, keymaker) 21 | } 22 | 23 | /// Caps the number of concurrent requests per channel. 24 | fn max_concurrent_requests_per_channel(self, n: usize) -> MaxRequestsPerChannel { 25 | MaxRequestsPerChannel::new(self, n) 26 | } 27 | 28 | /// Returns a stream of channels in execution. Each channel in execution is a stream of 29 | /// futures, where each future is an in-flight request being rsponded to. 30 | fn execute( 31 | self, 32 | serve: S, 33 | ) -> impl Stream>> 34 | where 35 | C::Req: RequestName, 36 | S: Serve + Clone, 37 | { 38 | self.map(move |channel| channel.execute(serve.clone())) 39 | } 40 | } 41 | 42 | #[cfg(feature = "tokio1")] 43 | /// Spawns all channels-in-execution, delegating to the tokio runtime to manage their completion. 44 | /// Each channel is spawned, and each request from each channel is spawned. 45 | /// Note that this function is generic over any stream-of-streams-of-futures, but it is intended 46 | /// for spawning streams of channels. 
47 | /// 48 | /// # Example 49 | /// ```rust 50 | /// use tarpc::{ 51 | /// context, 52 | /// client::{self, NewClient}, 53 | /// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, 54 | /// transport, 55 | /// }; 56 | /// use futures::prelude::*; 57 | /// 58 | /// #[tokio::main] 59 | /// async fn main() { 60 | /// let (tx, rx) = transport::channel::unbounded(); 61 | /// let NewClient { client, dispatch } = client::new(client::Config::default(), tx); 62 | /// tokio::spawn(dispatch); 63 | /// 64 | /// let incoming = stream::once(async move { 65 | /// BaseChannel::new(server::Config::default(), rx) 66 | /// }).execute(serve(|_, i| async move { Ok(i + 1) })); 67 | /// tokio::spawn(spawn_incoming(incoming)); 68 | /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); 69 | /// } 70 | /// ``` 71 | pub async fn spawn_incoming( 72 | incoming: impl Stream< 73 | Item = impl Stream + Send + 'static> + Send + 'static, 74 | >, 75 | ) { 76 | use futures::pin_mut; 77 | pin_mut!(incoming); 78 | while let Some(channel) = incoming.next().await { 79 | tokio::spawn(async move { 80 | pin_mut!(channel); 81 | while let Some(request) = channel.next().await { 82 | tokio::spawn(request); 83 | } 84 | }); 85 | } 86 | } 87 | 88 | impl Incoming for S 89 | where 90 | S: Sized + Stream, 91 | C: Channel, 92 | { 93 | } 94 | -------------------------------------------------------------------------------- /tarpc/src/server/limits.rs: -------------------------------------------------------------------------------- 1 | /// Provides functionality to limit the number of active channels. 2 | pub mod channels_per_key; 3 | 4 | /// Provides a [channel](crate::server::Channel) that limits the number of in-flight requests. 
5 | pub mod requests_per_channel; 6 | -------------------------------------------------------------------------------- /tarpc/src/server/limits/channels_per_key.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | use crate::{ 8 | server::{self, Channel}, 9 | util::Compact, 10 | }; 11 | use fnv::FnvHashMap; 12 | use futures::{prelude::*, ready, stream::Fuse, task::*}; 13 | use pin_project::pin_project; 14 | use std::sync::{Arc, Weak}; 15 | use std::{ 16 | collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin, 17 | }; 18 | use tokio::sync::mpsc; 19 | use tracing::{debug, info, trace}; 20 | 21 | /// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on 22 | /// per-key limits. 23 | /// 24 | /// The decision to drop a Channel is made once at the time the Channel materializes. Once a 25 | /// Channel is yielded, it will not be prematurely dropped. 26 | #[pin_project] 27 | #[derive(Debug)] 28 | pub struct MaxChannelsPerKey 29 | where 30 | K: Eq + Hash, 31 | { 32 | #[pin] 33 | listener: Fuse, 34 | channels_per_key: u32, 35 | dropped_keys: mpsc::UnboundedReceiver, 36 | dropped_keys_tx: mpsc::UnboundedSender, 37 | key_counts: FnvHashMap>>, 38 | keymaker: F, 39 | } 40 | 41 | /// A channel that is tracked by [`MaxChannelsPerKey`]. 42 | #[pin_project] 43 | #[derive(Debug)] 44 | pub struct TrackedChannel { 45 | #[pin] 46 | inner: C, 47 | tracker: Arc>, 48 | } 49 | 50 | #[derive(Debug)] 51 | struct Tracker { 52 | key: Option, 53 | dropped_keys: mpsc::UnboundedSender, 54 | } 55 | 56 | impl Drop for Tracker { 57 | fn drop(&mut self) { 58 | // Don't care if the listener is dropped. 
59 | let _ = self.dropped_keys.send(self.key.take().unwrap()); 60 | } 61 | } 62 | 63 | impl Stream for TrackedChannel 64 | where 65 | C: Stream, 66 | { 67 | type Item = ::Item; 68 | 69 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 70 | self.inner_pin_mut().poll_next(cx) 71 | } 72 | } 73 | 74 | impl Sink for TrackedChannel 75 | where 76 | C: Sink, 77 | { 78 | type Error = C::Error; 79 | 80 | fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 81 | self.inner_pin_mut().poll_ready(cx) 82 | } 83 | 84 | fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { 85 | self.inner_pin_mut().start_send(item) 86 | } 87 | 88 | fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 89 | self.inner_pin_mut().poll_flush(cx) 90 | } 91 | 92 | fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 93 | self.inner_pin_mut().poll_close(cx) 94 | } 95 | } 96 | 97 | impl AsRef for TrackedChannel { 98 | fn as_ref(&self) -> &C { 99 | &self.inner 100 | } 101 | } 102 | 103 | impl Channel for TrackedChannel 104 | where 105 | C: Channel, 106 | { 107 | type Req = C::Req; 108 | type Resp = C::Resp; 109 | type Transport = C::Transport; 110 | 111 | fn config(&self) -> &server::Config { 112 | self.inner.config() 113 | } 114 | 115 | fn in_flight_requests(&self) -> usize { 116 | self.inner.in_flight_requests() 117 | } 118 | 119 | fn transport(&self) -> &Self::Transport { 120 | self.inner.transport() 121 | } 122 | } 123 | 124 | impl TrackedChannel { 125 | /// Returns the inner channel. 126 | pub fn get_ref(&self) -> &C { 127 | &self.inner 128 | } 129 | 130 | /// Returns the pinned inner channel. 131 | fn inner_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> { 132 | self.as_mut().project().inner 133 | } 134 | } 135 | 136 | impl MaxChannelsPerKey 137 | where 138 | K: Eq + Hash, 139 | S: Stream, 140 | F: Fn(&S::Item) -> K, 141 | { 142 | /// Sheds new channels to stay under configured limits. 
143 | pub(crate) fn new(listener: S, channels_per_key: u32, keymaker: F) -> Self { 144 | let (dropped_keys_tx, dropped_keys) = mpsc::unbounded_channel(); 145 | MaxChannelsPerKey { 146 | listener: listener.fuse(), 147 | channels_per_key, 148 | dropped_keys, 149 | dropped_keys_tx, 150 | key_counts: FnvHashMap::default(), 151 | keymaker, 152 | } 153 | } 154 | } 155 | 156 | impl MaxChannelsPerKey 157 | where 158 | S: Stream, 159 | K: fmt::Display + Eq + Hash + Clone + Unpin, 160 | F: Fn(&S::Item) -> K, 161 | { 162 | fn listener_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse> { 163 | self.as_mut().project().listener 164 | } 165 | 166 | fn handle_new_channel( 167 | mut self: Pin<&mut Self>, 168 | stream: S::Item, 169 | ) -> Result, K> { 170 | let key = (self.as_mut().keymaker)(&stream); 171 | let tracker = self.as_mut().increment_channels_for_key(key.clone())?; 172 | 173 | trace!( 174 | channel_filter_key = %key, 175 | open_channels = Arc::strong_count(&tracker), 176 | max_open_channels = self.channels_per_key, 177 | "Opening channel"); 178 | 179 | Ok(TrackedChannel { 180 | tracker, 181 | inner: stream, 182 | }) 183 | } 184 | 185 | fn increment_channels_for_key(self: Pin<&mut Self>, key: K) -> Result>, K> { 186 | let self_ = self.project(); 187 | let dropped_keys = self_.dropped_keys_tx; 188 | match self_.key_counts.entry(key.clone()) { 189 | Entry::Vacant(vacant) => { 190 | let tracker = Arc::new(Tracker { 191 | key: Some(key), 192 | dropped_keys: dropped_keys.clone(), 193 | }); 194 | 195 | vacant.insert(Arc::downgrade(&tracker)); 196 | Ok(tracker) 197 | } 198 | Entry::Occupied(mut o) => { 199 | let count = o.get().strong_count(); 200 | if count >= usize::try_from(*self_.channels_per_key).unwrap() { 201 | info!( 202 | channel_filter_key = %key, 203 | open_channels = count, 204 | max_open_channels = *self_.channels_per_key, 205 | "At open channel limit"); 206 | Err(key) 207 | } else { 208 | Ok(o.get().upgrade().unwrap_or_else(|| { 209 | let tracker = 
Arc::new(Tracker { 210 | key: Some(key), 211 | dropped_keys: dropped_keys.clone(), 212 | }); 213 | 214 | *o.get_mut() = Arc::downgrade(&tracker); 215 | tracker 216 | })) 217 | } 218 | } 219 | } 220 | } 221 | 222 | fn poll_listener( 223 | mut self: Pin<&mut Self>, 224 | cx: &mut Context<'_>, 225 | ) -> Poll, K>>> { 226 | match ready!(self.listener_pin_mut().poll_next_unpin(cx)) { 227 | Some(codec) => Poll::Ready(Some(self.handle_new_channel(codec))), 228 | None => Poll::Ready(None), 229 | } 230 | } 231 | 232 | fn poll_closed_channels(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { 233 | let self_ = self.project(); 234 | match ready!(self_.dropped_keys.poll_recv(cx)) { 235 | Some(key) => { 236 | debug!( 237 | channel_filter_key = %key, 238 | "All channels dropped"); 239 | self_.key_counts.remove(&key); 240 | self_.key_counts.compact(0.1); 241 | Poll::Ready(()) 242 | } 243 | None => unreachable!("Holding a copy of closed_channels and didn't close it."), 244 | } 245 | } 246 | } 247 | 248 | impl Stream for MaxChannelsPerKey 249 | where 250 | S: Stream, 251 | K: fmt::Display + Eq + Hash + Clone + Unpin, 252 | F: Fn(&S::Item) -> K, 253 | { 254 | type Item = TrackedChannel; 255 | 256 | fn poll_next( 257 | mut self: Pin<&mut Self>, 258 | cx: &mut Context<'_>, 259 | ) -> Poll>> { 260 | loop { 261 | match ( 262 | self.as_mut().poll_listener(cx), 263 | self.as_mut().poll_closed_channels(cx), 264 | ) { 265 | (Poll::Ready(Some(Ok(channel))), _) => { 266 | return Poll::Ready(Some(channel)); 267 | } 268 | (Poll::Ready(Some(Err(_))), _) => { 269 | continue; 270 | } 271 | (_, Poll::Ready(())) => continue, 272 | (Poll::Pending, Poll::Pending) => return Poll::Pending, 273 | (Poll::Ready(None), Poll::Pending) => { 274 | trace!("Shutting down listener."); 275 | return Poll::Ready(None); 276 | } 277 | } 278 | } 279 | } 280 | } 281 | #[cfg(test)] 282 | fn ctx() -> Context<'static> { 283 | use futures::task::*; 284 | 285 | Context::from_waker(noop_waker_ref()) 286 | } 287 | 288 | 
#[test] 289 | fn tracker_drop() { 290 | use assert_matches::assert_matches; 291 | 292 | let (tx, mut rx) = mpsc::unbounded_channel(); 293 | Tracker { 294 | key: Some(1), 295 | dropped_keys: tx, 296 | }; 297 | assert_matches!(rx.poll_recv(&mut ctx()), Poll::Ready(Some(1))); 298 | } 299 | 300 | #[test] 301 | fn tracked_channel_stream() { 302 | use assert_matches::assert_matches; 303 | use pin_utils::pin_mut; 304 | 305 | let (chan_tx, chan) = futures::channel::mpsc::unbounded(); 306 | let (dropped_keys, _) = mpsc::unbounded_channel(); 307 | let channel = TrackedChannel { 308 | inner: chan, 309 | tracker: Arc::new(Tracker { 310 | key: Some(1), 311 | dropped_keys, 312 | }), 313 | }; 314 | 315 | chan_tx.unbounded_send("test").unwrap(); 316 | pin_mut!(channel); 317 | assert_matches!(channel.poll_next(&mut ctx()), Poll::Ready(Some("test"))); 318 | } 319 | 320 | #[test] 321 | fn tracked_channel_sink() { 322 | use assert_matches::assert_matches; 323 | use pin_utils::pin_mut; 324 | 325 | let (chan, mut chan_rx) = futures::channel::mpsc::unbounded(); 326 | let (dropped_keys, _) = mpsc::unbounded_channel(); 327 | let channel = TrackedChannel { 328 | inner: chan, 329 | tracker: Arc::new(Tracker { 330 | key: Some(1), 331 | dropped_keys, 332 | }), 333 | }; 334 | 335 | pin_mut!(channel); 336 | assert_matches!(channel.as_mut().poll_ready(&mut ctx()), Poll::Ready(Ok(()))); 337 | assert_matches!(channel.as_mut().start_send("test"), Ok(())); 338 | assert_matches!(channel.as_mut().poll_flush(&mut ctx()), Poll::Ready(Ok(()))); 339 | assert_matches!(chan_rx.try_next(), Ok(Some("test"))); 340 | } 341 | 342 | #[test] 343 | fn channel_filter_increment_channels_for_key() { 344 | use assert_matches::assert_matches; 345 | use pin_utils::pin_mut; 346 | 347 | struct TestChannel { 348 | key: &'static str, 349 | } 350 | let (_, listener) = futures::channel::mpsc::unbounded(); 351 | let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); 352 | pin_mut!(filter); 353 | let 
tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap(); 354 | assert_eq!(Arc::strong_count(&tracker1), 1); 355 | let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap(); 356 | assert_eq!(Arc::strong_count(&tracker1), 2); 357 | assert_matches!(filter.increment_channels_for_key("key"), Err("key")); 358 | drop(tracker2); 359 | assert_eq!(Arc::strong_count(&tracker1), 1); 360 | } 361 | 362 | #[test] 363 | fn channel_filter_handle_new_channel() { 364 | use assert_matches::assert_matches; 365 | use pin_utils::pin_mut; 366 | 367 | #[derive(Debug)] 368 | struct TestChannel { 369 | key: &'static str, 370 | } 371 | let (_, listener) = futures::channel::mpsc::unbounded(); 372 | let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); 373 | pin_mut!(filter); 374 | let channel1 = filter 375 | .as_mut() 376 | .handle_new_channel(TestChannel { key: "key" }) 377 | .unwrap(); 378 | assert_eq!(Arc::strong_count(&channel1.tracker), 1); 379 | 380 | let channel2 = filter 381 | .as_mut() 382 | .handle_new_channel(TestChannel { key: "key" }) 383 | .unwrap(); 384 | assert_eq!(Arc::strong_count(&channel1.tracker), 2); 385 | 386 | assert_matches!( 387 | filter.handle_new_channel(TestChannel { key: "key" }), 388 | Err("key") 389 | ); 390 | drop(channel2); 391 | assert_eq!(Arc::strong_count(&channel1.tracker), 1); 392 | } 393 | 394 | #[test] 395 | fn channel_filter_poll_listener() { 396 | use assert_matches::assert_matches; 397 | use pin_utils::pin_mut; 398 | 399 | #[derive(Debug)] 400 | struct TestChannel { 401 | key: &'static str, 402 | } 403 | let (new_channels, listener) = futures::channel::mpsc::unbounded(); 404 | let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); 405 | pin_mut!(filter); 406 | 407 | new_channels 408 | .unbounded_send(TestChannel { key: "key" }) 409 | .unwrap(); 410 | let channel1 = 411 | assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c); 412 | 
assert_eq!(Arc::strong_count(&channel1.tracker), 1); 413 | 414 | new_channels 415 | .unbounded_send(TestChannel { key: "key" }) 416 | .unwrap(); 417 | let _channel2 = 418 | assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c); 419 | assert_eq!(Arc::strong_count(&channel1.tracker), 2); 420 | 421 | new_channels 422 | .unbounded_send(TestChannel { key: "key" }) 423 | .unwrap(); 424 | let key = 425 | assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k); 426 | assert_eq!(key, "key"); 427 | assert_eq!(Arc::strong_count(&channel1.tracker), 2); 428 | } 429 | 430 | #[test] 431 | fn channel_filter_poll_closed_channels() { 432 | use assert_matches::assert_matches; 433 | use pin_utils::pin_mut; 434 | 435 | #[derive(Debug)] 436 | struct TestChannel { 437 | key: &'static str, 438 | } 439 | let (new_channels, listener) = futures::channel::mpsc::unbounded(); 440 | let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); 441 | pin_mut!(filter); 442 | 443 | new_channels 444 | .unbounded_send(TestChannel { key: "key" }) 445 | .unwrap(); 446 | let channel = 447 | assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c); 448 | assert_eq!(filter.key_counts.len(), 1); 449 | 450 | drop(channel); 451 | assert_matches!( 452 | filter.as_mut().poll_closed_channels(&mut ctx()), 453 | Poll::Ready(()) 454 | ); 455 | assert!(filter.key_counts.is_empty()); 456 | } 457 | 458 | #[test] 459 | fn channel_filter_stream() { 460 | use assert_matches::assert_matches; 461 | use pin_utils::pin_mut; 462 | 463 | #[derive(Debug)] 464 | struct TestChannel { 465 | key: &'static str, 466 | } 467 | let (new_channels, listener) = futures::channel::mpsc::unbounded(); 468 | let filter = MaxChannelsPerKey::new(listener, 2, |chan: &TestChannel| chan.key); 469 | pin_mut!(filter); 470 | 471 | new_channels 472 | .unbounded_send(TestChannel { key: "key" }) 473 | .unwrap(); 474 | let channel = 
assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Ready(Some(c)) => c); 475 | assert_eq!(filter.key_counts.len(), 1); 476 | 477 | drop(channel); 478 | assert_matches!(filter.as_mut().poll_next(&mut ctx()), Poll::Pending); 479 | assert!(filter.key_counts.is_empty()); 480 | } 481 | -------------------------------------------------------------------------------- /tarpc/src/server/limits/requests_per_channel.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2020 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | use crate::{ 8 | server::{Channel, Config}, 9 | Response, ServerError, 10 | }; 11 | use futures::{prelude::*, ready, task::*}; 12 | use pin_project::pin_project; 13 | use std::{io, pin::Pin}; 14 | 15 | /// A [`Channel`] that limits the number of concurrent requests by throttling. 16 | /// 17 | /// Note that this is a very basic throttling heuristic. It is easy to set a number that is too low 18 | /// for the resources available to the server. For production use cases, a more advanced throttler 19 | /// is likely needed. 20 | #[pin_project] 21 | #[derive(Debug)] 22 | pub struct MaxRequests { 23 | max_in_flight_requests: usize, 24 | #[pin] 25 | inner: C, 26 | } 27 | 28 | impl MaxRequests { 29 | /// Returns the inner channel. 30 | pub fn get_ref(&self) -> &C { 31 | &self.inner 32 | } 33 | } 34 | 35 | impl MaxRequests 36 | where 37 | C: Channel, 38 | { 39 | /// Returns a new `MaxRequests` that wraps the given channel and limits concurrent requests to 40 | /// `max_in_flight_requests`. 
41 | pub fn new(inner: C, max_in_flight_requests: usize) -> Self { 42 | MaxRequests { 43 | max_in_flight_requests, 44 | inner, 45 | } 46 | } 47 | } 48 | 49 | impl Stream for MaxRequests 50 | where 51 | C: Channel, 52 | { 53 | type Item = ::Item; 54 | 55 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 56 | while self.as_mut().in_flight_requests() >= *self.as_mut().project().max_in_flight_requests 57 | { 58 | ready!(self.as_mut().project().inner.poll_ready(cx)?); 59 | 60 | match ready!(self.as_mut().project().inner.poll_next(cx)?) { 61 | Some(r) => { 62 | let _entered = r.span.enter(); 63 | tracing::info!( 64 | in_flight_requests = self.as_mut().in_flight_requests(), 65 | "ThrottleRequest", 66 | ); 67 | 68 | self.as_mut().start_send(Response { 69 | request_id: r.request.id, 70 | message: Err(ServerError { 71 | kind: io::ErrorKind::WouldBlock, 72 | detail: "server throttled the request.".into(), 73 | }), 74 | })?; 75 | } 76 | None => return Poll::Ready(None), 77 | } 78 | } 79 | self.project().inner.poll_next(cx) 80 | } 81 | } 82 | 83 | impl Sink::Resp>> for MaxRequests 84 | where 85 | C: Channel, 86 | { 87 | type Error = C::Error; 88 | 89 | fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 90 | self.project().inner.poll_ready(cx) 91 | } 92 | 93 | fn start_send( 94 | self: Pin<&mut Self>, 95 | item: Response<::Resp>, 96 | ) -> Result<(), Self::Error> { 97 | self.project().inner.start_send(item) 98 | } 99 | 100 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 101 | self.project().inner.poll_flush(cx) 102 | } 103 | 104 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 105 | self.project().inner.poll_close(cx) 106 | } 107 | } 108 | 109 | impl AsRef for MaxRequests { 110 | fn as_ref(&self) -> &C { 111 | &self.inner 112 | } 113 | } 114 | 115 | impl Channel for MaxRequests 116 | where 117 | C: Channel, 118 | { 119 | type Req = ::Req; 120 | type Resp = ::Resp; 121 | type Transport = ::Transport; 122 | 123 
| fn in_flight_requests(&self) -> usize { 124 | self.inner.in_flight_requests() 125 | } 126 | 127 | fn config(&self) -> &Config { 128 | self.inner.config() 129 | } 130 | 131 | fn transport(&self) -> &Self::Transport { 132 | self.inner.transport() 133 | } 134 | } 135 | 136 | /// An [`Incoming`](crate::server::incoming::Incoming) stream of channels that enforce limits on 137 | /// the number of in-flight requests. 138 | #[pin_project] 139 | #[derive(Debug)] 140 | pub struct MaxRequestsPerChannel { 141 | #[pin] 142 | inner: S, 143 | max_in_flight_requests: usize, 144 | } 145 | 146 | impl MaxRequestsPerChannel 147 | where 148 | S: Stream, 149 | ::Item: Channel, 150 | { 151 | pub(crate) fn new(inner: S, max_in_flight_requests: usize) -> Self { 152 | Self { 153 | inner, 154 | max_in_flight_requests, 155 | } 156 | } 157 | } 158 | 159 | impl Stream for MaxRequestsPerChannel 160 | where 161 | S: Stream, 162 | ::Item: Channel, 163 | { 164 | type Item = MaxRequests<::Item>; 165 | 166 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 167 | match ready!(self.as_mut().project().inner.poll_next(cx)) { 168 | Some(channel) => Poll::Ready(Some(MaxRequests::new( 169 | channel, 170 | *self.project().max_in_flight_requests, 171 | ))), 172 | None => Poll::Ready(None), 173 | } 174 | } 175 | } 176 | 177 | #[cfg(test)] 178 | mod tests { 179 | use super::*; 180 | 181 | use crate::server::{ 182 | testing::{self, FakeChannel, PollExt}, 183 | TrackedRequest, 184 | }; 185 | use pin_utils::pin_mut; 186 | use std::{ 187 | marker::PhantomData, 188 | time::{Duration, Instant}, 189 | }; 190 | use tracing::Span; 191 | 192 | #[tokio::test] 193 | async fn throttler_in_flight_requests() { 194 | let throttler = MaxRequests { 195 | max_in_flight_requests: 0, 196 | inner: FakeChannel::default::(), 197 | }; 198 | 199 | pin_mut!(throttler); 200 | for i in 0..5 { 201 | throttler 202 | .inner 203 | .in_flight_requests 204 | .start_request(i, Instant::now() + Duration::from_secs(1), 
Span::current()) 205 | .unwrap(); 206 | } 207 | assert_eq!(throttler.as_mut().in_flight_requests(), 5); 208 | } 209 | 210 | #[test] 211 | fn throttler_poll_next_done() { 212 | let throttler = MaxRequests { 213 | max_in_flight_requests: 0, 214 | inner: FakeChannel::default::(), 215 | }; 216 | 217 | pin_mut!(throttler); 218 | assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done()); 219 | } 220 | 221 | #[test] 222 | fn throttler_poll_next_some() -> io::Result<()> { 223 | let throttler = MaxRequests { 224 | max_in_flight_requests: 1, 225 | inner: FakeChannel::default::(), 226 | }; 227 | 228 | pin_mut!(throttler); 229 | throttler.inner.push_req(0, 1); 230 | assert!(throttler.as_mut().poll_ready(&mut testing::cx()).is_ready()); 231 | assert_eq!( 232 | throttler 233 | .as_mut() 234 | .poll_next(&mut testing::cx())? 235 | .map(|r| r.map(|r| (r.request.id, r.request.message))), 236 | Poll::Ready(Some((0, 1))) 237 | ); 238 | Ok(()) 239 | } 240 | 241 | #[test] 242 | fn throttler_poll_next_throttled() { 243 | let throttler = MaxRequests { 244 | max_in_flight_requests: 0, 245 | inner: FakeChannel::default::(), 246 | }; 247 | 248 | pin_mut!(throttler); 249 | throttler.inner.push_req(1, 1); 250 | assert!(throttler.as_mut().poll_next(&mut testing::cx()).is_done()); 251 | assert_eq!(throttler.inner.sink.len(), 1); 252 | let resp = throttler.inner.sink.front().unwrap(); 253 | assert_eq!(resp.request_id, 1); 254 | assert!(resp.message.is_err()); 255 | } 256 | 257 | #[test] 258 | fn throttler_poll_next_throttled_sink_not_ready() { 259 | let throttler = MaxRequests { 260 | max_in_flight_requests: 0, 261 | inner: PendingSink::default::(), 262 | }; 263 | pin_mut!(throttler); 264 | assert!(throttler.poll_next(&mut testing::cx()).is_pending()); 265 | 266 | struct PendingSink { 267 | ghost: PhantomData In>, 268 | } 269 | impl PendingSink<(), ()> { 270 | pub fn default( 271 | ) -> PendingSink>, Response> { 272 | PendingSink { ghost: PhantomData } 273 | } 274 | } 275 | impl Stream 
for PendingSink { 276 | type Item = In; 277 | fn poll_next(self: Pin<&mut Self>, _: &mut Context) -> Poll> { 278 | unimplemented!() 279 | } 280 | } 281 | impl Sink for PendingSink { 282 | type Error = io::Error; 283 | fn poll_ready(self: Pin<&mut Self>, _: &mut Context) -> Poll> { 284 | Poll::Pending 285 | } 286 | fn start_send(self: Pin<&mut Self>, _: Out) -> Result<(), Self::Error> { 287 | Err(io::Error::from(io::ErrorKind::WouldBlock)) 288 | } 289 | fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll> { 290 | Poll::Pending 291 | } 292 | fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll> { 293 | Poll::Pending 294 | } 295 | } 296 | impl Channel for PendingSink>, Response> { 297 | type Req = Req; 298 | type Resp = Resp; 299 | type Transport = (); 300 | fn config(&self) -> &Config { 301 | unimplemented!() 302 | } 303 | fn in_flight_requests(&self) -> usize { 304 | 0 305 | } 306 | fn transport(&self) -> &() { 307 | &() 308 | } 309 | } 310 | } 311 | 312 | #[tokio::test] 313 | async fn throttler_start_send() { 314 | let throttler = MaxRequests { 315 | max_in_flight_requests: 0, 316 | inner: FakeChannel::default::(), 317 | }; 318 | 319 | pin_mut!(throttler); 320 | throttler 321 | .inner 322 | .in_flight_requests 323 | .start_request(0, Instant::now() + Duration::from_secs(1), Span::current()) 324 | .unwrap(); 325 | throttler 326 | .as_mut() 327 | .start_send(Response { 328 | request_id: 0, 329 | message: Ok(1), 330 | }) 331 | .unwrap(); 332 | assert_eq!(throttler.inner.in_flight_requests.len(), 0); 333 | assert_eq!( 334 | throttler.inner.sink.front(), 335 | Some(&Response { 336 | request_id: 0, 337 | message: Ok(1), 338 | }) 339 | ); 340 | } 341 | } 342 | -------------------------------------------------------------------------------- /tarpc/src/server/request_hook.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2022 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | 
// license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | //! Hooks for horizontal functionality that can run either before or after a request is executed. 8 | 9 | use crate::server::Serve; 10 | 11 | /// A request hook that runs before a request is executed. 12 | mod before; 13 | 14 | /// A request hook that runs after a request is completed. 15 | mod after; 16 | 17 | /// A request hook that runs both before a request is executed and after it is completed. 18 | mod before_and_after; 19 | 20 | pub use { 21 | after::{AfterRequest, ServeThenHook}, 22 | before::{ 23 | before, BeforeRequest, BeforeRequestCons, BeforeRequestList, BeforeRequestNil, 24 | HookThenServe, 25 | }, 26 | before_and_after::HookThenServeThenHook, 27 | }; 28 | 29 | /// Hooks that run before and/or after serving a request. 30 | pub trait RequestHook: Serve { 31 | /// Runs a hook before execution of the request. 32 | /// 33 | /// If the hook returns an error, the request will not be executed and the error will be 34 | /// returned instead. 35 | /// 36 | /// The hook can also modify the request context. This could be used, for example, to enforce a 37 | /// maximum deadline on all requests. 38 | /// 39 | /// Any type that implements [`BeforeRequest`] can be used as the hook. Types that implement 40 | /// `FnMut(&mut Context, &RequestType) -> impl Future>` can 41 | /// also be used. 
42 | /// 43 | /// # Example 44 | /// 45 | /// ```rust 46 | /// use futures::{executor::block_on, future}; 47 | /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; 48 | /// use std::io; 49 | /// 50 | /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }) 51 | /// .before(|_ctx: &mut context::Context, req: &i32| { 52 | /// future::ready( 53 | /// if *req == 1 { 54 | /// Err(ServerError::new( 55 | /// io::ErrorKind::Other, 56 | /// format!("I don't like {req}"))) 57 | /// } else { 58 | /// Ok(()) 59 | /// }) 60 | /// }); 61 | /// let response = serve.serve(context::current(), 1); 62 | /// assert!(block_on(response).is_err()); 63 | /// ``` 64 | fn before(self, hook: Hook) -> HookThenServe 65 | where 66 | Hook: BeforeRequest, 67 | Self: Sized, 68 | { 69 | HookThenServe::new(self, hook) 70 | } 71 | 72 | /// Runs a hook after completion of a request. 73 | /// 74 | /// The hook can modify the request context and the response. 75 | /// 76 | /// Any type that implements [`AfterRequest`] can be used as the hook. Types that implement 77 | /// `FnMut(&mut Context, &mut Result) -> impl Future` 78 | /// can also be used. 
79 | /// 80 | /// # Example 81 | /// 82 | /// ```rust 83 | /// use futures::{executor::block_on, future}; 84 | /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; 85 | /// use std::io; 86 | /// 87 | /// let serve = serve( 88 | /// |_ctx, i| async move { 89 | /// if i == 1 { 90 | /// Err(ServerError::new( 91 | /// io::ErrorKind::Other, 92 | /// format!("{i} is the loneliest number"))) 93 | /// } else { 94 | /// Ok(i + 1) 95 | /// } 96 | /// }) 97 | /// .after(|_ctx: &mut context::Context, resp: &mut Result| { 98 | /// if let Err(e) = resp { 99 | /// eprintln!("server error: {e:?}"); 100 | /// } 101 | /// future::ready(()) 102 | /// }); 103 | /// 104 | /// let response = serve.serve(context::current(), 1); 105 | /// assert!(block_on(response).is_err()); 106 | /// ``` 107 | fn after(self, hook: Hook) -> ServeThenHook 108 | where 109 | Hook: AfterRequest, 110 | Self: Sized, 111 | { 112 | ServeThenHook::new(self, hook) 113 | } 114 | 115 | /// Runs a hook before and after execution of the request. 116 | /// 117 | /// If the hook returns an error, the request will not be executed and the error will be 118 | /// returned instead. 119 | /// 120 | /// The hook can also modify the request context and the response. This could be used, for 121 | /// example, to enforce a maximum deadline on all requests. 
122 | /// 123 | /// # Example 124 | /// 125 | /// ```rust 126 | /// use futures::{executor::block_on, future}; 127 | /// use tarpc::{ 128 | /// context, ServerError, 129 | /// server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest, RequestHook}} 130 | /// }; 131 | /// use std::{io, time::Instant}; 132 | /// 133 | /// struct PrintLatency(Instant); 134 | /// 135 | /// impl BeforeRequest for PrintLatency { 136 | /// async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { 137 | /// self.0 = Instant::now(); 138 | /// Ok(()) 139 | /// } 140 | /// } 141 | /// 142 | /// impl AfterRequest for PrintLatency { 143 | /// async fn after( 144 | /// &mut self, 145 | /// _: &mut context::Context, 146 | /// _: &mut Result, 147 | /// ) { 148 | /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); 149 | /// } 150 | /// } 151 | /// 152 | /// let serve = serve(|_ctx, i| async move { 153 | /// Ok(i + 1) 154 | /// }).before_and_after(PrintLatency(Instant::now())); 155 | /// let response = serve.serve(context::current(), 1); 156 | /// assert!(block_on(response).is_ok()); 157 | /// ``` 158 | fn before_and_after( 159 | self, 160 | hook: Hook, 161 | ) -> HookThenServeThenHook 162 | where 163 | Hook: BeforeRequest + AfterRequest, 164 | Self: Sized, 165 | { 166 | HookThenServeThenHook::new(self, hook) 167 | } 168 | } 169 | impl RequestHook for S {} 170 | -------------------------------------------------------------------------------- /tarpc/src/server/request_hook/after.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2022 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | //! Provides a hook that runs after request execution. 
8 | 9 | use crate::{context, server::Serve, ServerError}; 10 | use futures::prelude::*; 11 | 12 | /// A hook that runs after request execution. 13 | #[allow(async_fn_in_trait)] 14 | pub trait AfterRequest { 15 | /// The function that is called after request execution. 16 | /// 17 | /// The hook can modify the request context and the response. 18 | async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result); 19 | } 20 | 21 | impl AfterRequest for F 22 | where 23 | F: FnMut(&mut context::Context, &mut Result) -> Fut, 24 | Fut: Future, 25 | { 26 | async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result) { 27 | self(ctx, resp).await 28 | } 29 | } 30 | 31 | /// A Service function that runs a hook after request execution. 32 | pub struct ServeThenHook { 33 | serve: Serv, 34 | hook: Hook, 35 | } 36 | 37 | impl ServeThenHook { 38 | pub(crate) fn new(serve: Serv, hook: Hook) -> Self { 39 | Self { serve, hook } 40 | } 41 | } 42 | 43 | impl Clone for ServeThenHook { 44 | fn clone(&self) -> Self { 45 | Self { 46 | serve: self.serve.clone(), 47 | hook: self.hook.clone(), 48 | } 49 | } 50 | } 51 | 52 | impl Serve for ServeThenHook 53 | where 54 | Serv: Serve, 55 | Hook: AfterRequest, 56 | { 57 | type Req = Serv::Req; 58 | type Resp = Serv::Resp; 59 | 60 | async fn serve( 61 | self, 62 | mut ctx: context::Context, 63 | req: Serv::Req, 64 | ) -> Result { 65 | let ServeThenHook { 66 | serve, mut hook, .. 67 | } = self; 68 | let mut resp = serve.serve(ctx, req).await; 69 | hook.after(&mut ctx, &mut resp).await; 70 | resp 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /tarpc/src/server/request_hook/before.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2022 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | //! 
Provides a hook that runs before request execution. 8 | 9 | use crate::{context, server::Serve, ServerError}; 10 | use futures::prelude::*; 11 | 12 | /// A hook that runs before request execution. 13 | #[allow(async_fn_in_trait)] 14 | pub trait BeforeRequest { 15 | /// The function that is called before request execution. 16 | /// 17 | /// If this function returns an error, the request will not be executed and the error will be 18 | /// returned instead. 19 | /// 20 | /// This function can also modify the request context. This could be used, for example, to 21 | /// enforce a maximum deadline on all requests. 22 | async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>; 23 | } 24 | 25 | /// A list of hooks that run in order before request execution. 26 | pub trait BeforeRequestList: BeforeRequest { 27 | /// The hook returned by `BeforeRequestList::then`. 28 | type Then: BeforeRequest 29 | where 30 | Next: BeforeRequest; 31 | 32 | /// Returns a hook that, when run, runs two hooks, first `self` and then `next`. 33 | fn then>(self, next: Next) -> Self::Then; 34 | 35 | /// Same as `then`, but helps the compiler with type inference when Next is a closure. 36 | fn then_fn< 37 | Next: FnMut(&mut context::Context, &Req) -> Fut, 38 | Fut: Future>, 39 | >( 40 | self, 41 | next: Next, 42 | ) -> Self::Then 43 | where 44 | Self: Sized, 45 | { 46 | self.then(next) 47 | } 48 | 49 | /// The service fn returned by `BeforeRequestList::serving`. 50 | type Serve>: Serve; 51 | 52 | /// Runs the list of request hooks before execution of the given serve fn. 53 | /// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer. 
54 | fn serving>(self, serve: S) -> Self::Serve; 55 | } 56 | 57 | impl BeforeRequest for F 58 | where 59 | F: FnMut(&mut context::Context, &Req) -> Fut, 60 | Fut: Future>, 61 | { 62 | async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { 63 | self(ctx, req).await 64 | } 65 | } 66 | 67 | /// A Service function that runs a hook before request execution. 68 | #[derive(Clone)] 69 | pub struct HookThenServe { 70 | serve: Serv, 71 | hook: Hook, 72 | } 73 | 74 | impl HookThenServe { 75 | pub(crate) fn new(serve: Serv, hook: Hook) -> Self { 76 | Self { serve, hook } 77 | } 78 | } 79 | 80 | impl Serve for HookThenServe 81 | where 82 | Serv: Serve, 83 | Hook: BeforeRequest, 84 | { 85 | type Req = Serv::Req; 86 | type Resp = Serv::Resp; 87 | 88 | async fn serve( 89 | self, 90 | mut ctx: context::Context, 91 | req: Self::Req, 92 | ) -> Result { 93 | let HookThenServe { 94 | serve, mut hook, .. 95 | } = self; 96 | hook.before(&mut ctx, &req).await?; 97 | serve.serve(ctx, req).await 98 | } 99 | } 100 | 101 | /// Returns a request hook builder that runs a series of hooks before request execution. 
102 | /// 103 | /// Example 104 | /// 105 | /// ```rust 106 | /// use futures::{executor::block_on, future}; 107 | /// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self, 108 | /// BeforeRequest, BeforeRequestList}}}; 109 | /// use std::{cell::Cell, io}; 110 | /// 111 | /// let i = Cell::new(0); 112 | /// let serve = request_hook::before() 113 | /// .then_fn(|_, _| async { 114 | /// assert!(i.get() == 0); 115 | /// i.set(1); 116 | /// Ok(()) 117 | /// }) 118 | /// .then_fn(|_, _| async { 119 | /// assert!(i.get() == 1); 120 | /// i.set(2); 121 | /// Ok(()) 122 | /// }) 123 | /// .serving(serve(|_ctx, i| async move { Ok(i + 1) })); 124 | /// let response = serve.clone().serve(context::current(), 1); 125 | /// assert!(block_on(response).is_ok()); 126 | /// assert!(i.get() == 2); 127 | /// ``` 128 | pub fn before() -> BeforeRequestNil { 129 | BeforeRequestNil 130 | } 131 | 132 | /// A list of hooks that run in order before a request is executed. 133 | #[derive(Clone, Copy)] 134 | pub struct BeforeRequestCons(First, Rest); 135 | 136 | /// A noop hook that runs before a request is executed. 
137 | #[derive(Clone, Copy)] 138 | pub struct BeforeRequestNil; 139 | 140 | impl, Rest: BeforeRequest> BeforeRequest 141 | for BeforeRequestCons 142 | { 143 | async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { 144 | let BeforeRequestCons(first, rest) = self; 145 | first.before(ctx, req).await?; 146 | rest.before(ctx, req).await?; 147 | Ok(()) 148 | } 149 | } 150 | 151 | impl BeforeRequest for BeforeRequestNil { 152 | async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { 153 | Ok(()) 154 | } 155 | } 156 | 157 | impl, Rest: BeforeRequestList> BeforeRequestList 158 | for BeforeRequestCons 159 | { 160 | type Then 161 | = BeforeRequestCons> 162 | where 163 | Next: BeforeRequest; 164 | 165 | fn then>(self, next: Next) -> Self::Then { 166 | let BeforeRequestCons(first, rest) = self; 167 | BeforeRequestCons(first, rest.then(next)) 168 | } 169 | 170 | type Serve> = HookThenServe; 171 | 172 | fn serving>(self, serve: S) -> Self::Serve { 173 | HookThenServe::new(serve, self) 174 | } 175 | } 176 | 177 | impl BeforeRequestList for BeforeRequestNil { 178 | type Then 179 | = BeforeRequestCons 180 | where 181 | Next: BeforeRequest; 182 | 183 | fn then>(self, next: Next) -> Self::Then { 184 | BeforeRequestCons(next, BeforeRequestNil) 185 | } 186 | 187 | type Serve> = S; 188 | 189 | fn serving>(self, serve: S) -> S { 190 | serve 191 | } 192 | } 193 | 194 | #[test] 195 | fn before_request_list() { 196 | use crate::server::serve; 197 | use futures::executor::block_on; 198 | use std::cell::Cell; 199 | 200 | let i = Cell::new(0); 201 | let serve = before() 202 | .then_fn(|_, _| async { 203 | assert!(i.get() == 0); 204 | i.set(1); 205 | Ok(()) 206 | }) 207 | .then_fn(|_, _| async { 208 | assert!(i.get() == 1); 209 | i.set(2); 210 | Ok(()) 211 | }) 212 | .serving(serve(|_ctx, i| async move { Ok(i + 1) })); 213 | let response = serve.clone().serve(context::current(), 1); 214 | 
assert!(block_on(response).is_ok()); 215 | assert!(i.get() == 2); 216 | } 217 | -------------------------------------------------------------------------------- /tarpc/src/server/request_hook/before_and_after.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2022 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | //! Provides a hook that runs both before and after request execution. 8 | 9 | use super::{after::AfterRequest, before::BeforeRequest}; 10 | use crate::{context, server::Serve, RequestName, ServerError}; 11 | use std::marker::PhantomData; 12 | 13 | /// A Service function that runs a hook both before and after request execution. 14 | pub struct HookThenServeThenHook { 15 | serve: Serv, 16 | hook: Hook, 17 | fns: PhantomData<(fn(Req), fn(Resp))>, 18 | } 19 | 20 | impl HookThenServeThenHook { 21 | pub(crate) fn new(serve: Serv, hook: Hook) -> Self { 22 | Self { 23 | serve, 24 | hook, 25 | fns: PhantomData, 26 | } 27 | } 28 | } 29 | 30 | impl Clone for HookThenServeThenHook { 31 | fn clone(&self) -> Self { 32 | Self { 33 | serve: self.serve.clone(), 34 | hook: self.hook.clone(), 35 | fns: PhantomData, 36 | } 37 | } 38 | } 39 | 40 | impl Serve for HookThenServeThenHook 41 | where 42 | Req: RequestName, 43 | Serv: Serve, 44 | Hook: BeforeRequest + AfterRequest, 45 | { 46 | type Req = Req; 47 | type Resp = Resp; 48 | 49 | async fn serve(self, mut ctx: context::Context, req: Req) -> Result { 50 | let HookThenServeThenHook { 51 | serve, mut hook, .. 
52 | } = self; 53 | hook.before(&mut ctx, &req).await?; 54 | let mut resp = serve.serve(ctx, req).await; 55 | hook.after(&mut ctx, &mut resp).await; 56 | resp 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /tarpc/src/server/testing.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2020 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | use crate::{ 8 | cancellations::{cancellations, CanceledRequests, RequestCancellation}, 9 | context, 10 | server::{Channel, Config, ResponseGuard, TrackedRequest}, 11 | Request, Response, 12 | }; 13 | use futures::{task::*, Sink, Stream}; 14 | use pin_project::pin_project; 15 | use std::{collections::VecDeque, io, pin::Pin, time::Instant}; 16 | use tracing::Span; 17 | 18 | #[pin_project] 19 | pub(crate) struct FakeChannel { 20 | #[pin] 21 | pub stream: VecDeque, 22 | #[pin] 23 | pub sink: VecDeque, 24 | pub config: Config, 25 | pub in_flight_requests: super::in_flight_requests::InFlightRequests, 26 | pub request_cancellation: RequestCancellation, 27 | pub canceled_requests: CanceledRequests, 28 | } 29 | 30 | impl Stream for FakeChannel 31 | where 32 | In: Unpin, 33 | { 34 | type Item = In; 35 | 36 | fn poll_next(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { 37 | Poll::Ready(self.project().stream.pop_front()) 38 | } 39 | } 40 | 41 | impl Sink> for FakeChannel> { 42 | type Error = io::Error; 43 | 44 | fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 45 | self.project().sink.poll_ready(cx).map_err(|e| match e {}) 46 | } 47 | 48 | fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { 49 | self.as_mut() 50 | .project() 51 | .in_flight_requests 52 | .remove_request(response.request_id); 53 | self.project() 54 | .sink 55 | .start_send(response) 56 
| .map_err(|e| match e {}) 57 | } 58 | 59 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 60 | self.project().sink.poll_flush(cx).map_err(|e| match e {}) 61 | } 62 | 63 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 64 | self.project().sink.poll_close(cx).map_err(|e| match e {}) 65 | } 66 | } 67 | 68 | impl Channel for FakeChannel>, Response> 69 | where 70 | Req: Unpin, 71 | { 72 | type Req = Req; 73 | type Resp = Resp; 74 | type Transport = (); 75 | 76 | fn config(&self) -> &Config { 77 | &self.config 78 | } 79 | 80 | fn in_flight_requests(&self) -> usize { 81 | self.in_flight_requests.len() 82 | } 83 | 84 | fn transport(&self) -> &() { 85 | &() 86 | } 87 | } 88 | 89 | impl FakeChannel>, Response> { 90 | pub fn push_req(&mut self, id: u64, message: Req) { 91 | let (_, abort_registration) = futures::future::AbortHandle::new_pair(); 92 | let (request_cancellation, _) = cancellations(); 93 | self.stream.push_back(Ok(TrackedRequest { 94 | request: Request { 95 | context: context::Context { 96 | deadline: Instant::now(), 97 | trace_context: Default::default(), 98 | }, 99 | id, 100 | message, 101 | }, 102 | abort_registration, 103 | span: Span::none(), 104 | response_guard: ResponseGuard { 105 | request_cancellation, 106 | request_id: id, 107 | cancel: false, 108 | }, 109 | })); 110 | } 111 | } 112 | 113 | impl FakeChannel<(), ()> { 114 | pub fn default() -> FakeChannel>, Response> { 115 | let (request_cancellation, canceled_requests) = cancellations(); 116 | FakeChannel { 117 | stream: Default::default(), 118 | sink: Default::default(), 119 | config: Default::default(), 120 | in_flight_requests: Default::default(), 121 | request_cancellation, 122 | canceled_requests, 123 | } 124 | } 125 | } 126 | 127 | pub trait PollExt { 128 | fn is_done(&self) -> bool; 129 | } 130 | 131 | impl PollExt for Poll> { 132 | fn is_done(&self) -> bool { 133 | matches!(self, Poll::Ready(None)) 134 | } 135 | } 136 | 137 | pub fn cx() -> Context<'static> 
{ 138 | Context::from_waker(noop_waker_ref()) 139 | } 140 | -------------------------------------------------------------------------------- /tarpc/src/trace.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | #![deny(missing_docs, missing_debug_implementations)] 8 | 9 | //! Provides building blocks for tracing distributed programs. 10 | //! 11 | //! A trace is logically a tree of causally-related events called spans. Traces are tracked via a 12 | //! [context](Context) that identifies the current trace, span, and parent of the current span. In 13 | //! distributed systems, a context can be sent from client to server to connect events occurring on 14 | //! either side. 15 | //! 16 | //! This crate's design is based on [opencensus 17 | //! tracing](https://opencensus.io/core-concepts/tracing/). 18 | 19 | use opentelemetry::trace::TraceContextExt; 20 | use rand::Rng; 21 | use std::{ 22 | convert::TryFrom, 23 | fmt::{self, Formatter}, 24 | num::{NonZeroU128, NonZeroU64}, 25 | }; 26 | use tracing_opentelemetry::OpenTelemetrySpanExt; 27 | 28 | /// A context for tracing the execution of processes, distributed or otherwise. 29 | /// 30 | /// Consists of a span identifying an event, an optional parent span identifying a causal event 31 | /// that triggered the current span, and a trace with which all related spans are associated. 32 | #[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)] 33 | #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] 34 | pub struct Context { 35 | /// An identifier of the trace associated with the current context. A trace ID is typically 36 | /// created at a root span and passed along through all causal events. 
37 | pub trace_id: TraceId, 38 | /// An identifier of the current span. In typical RPC usage, a span is created by a client 39 | /// before making an RPC, and the span ID is sent to the server. The server is free to create 40 | /// its own spans, for which it sets the client's span as the parent span. 41 | pub span_id: SpanId, 42 | /// Indicates whether a sampler has already decided whether or not to sample the trace 43 | /// associated with the Context. If `sampling_decision` is None, then a decision has not yet 44 | /// been made. Downstream samplers do not need to abide by "no sample" decisions--for example, 45 | /// an upstream client may choose to never sample, which may not make sense for the client's 46 | /// dependencies. On the other hand, if an upstream process has chosen to sample this trace, 47 | /// then the downstream samplers are expected to respect that decision and also sample the 48 | /// trace. Otherwise, the full trace would not be able to be reconstructed. 49 | pub sampling_decision: SamplingDecision, 50 | } 51 | 52 | /// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the 53 | /// same trace ID. 54 | #[derive(Default, PartialEq, Eq, Hash, Clone, Copy)] 55 | #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] 56 | pub struct TraceId(#[cfg_attr(feature = "serde1", serde(with = "u128_serde"))] u128); 57 | 58 | /// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace. 59 | #[derive(Default, PartialEq, Eq, Hash, Clone, Copy)] 60 | #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] 61 | pub struct SpanId(u64); 62 | 63 | /// Indicates whether a sampler has decided whether or not to sample the trace associated with the 64 | /// Context. 
Downstream samplers do not need to abide by "no sample" decisions--for example, an 65 | /// upstream client may choose to never sample, which may not make sense for the client's 66 | /// dependencies. On the other hand, if an upstream process has chosen to sample this trace, then 67 | /// the downstream samplers are expected to respect that decision and also sample the trace. 68 | /// Otherwise, the full trace would not be able to be reconstructed reliably. 69 | #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] 70 | #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] 71 | #[repr(u8)] 72 | pub enum SamplingDecision { 73 | /// The associated span was sampled by its creating process. Child spans must also be sampled. 74 | Sampled, 75 | /// The associated span was not sampled by its creating process. 76 | Unsampled, 77 | } 78 | 79 | impl Context { 80 | /// Constructs a new context with the trace ID and sampling decision inherited from the parent. 81 | pub(crate) fn new_child(&self) -> Self { 82 | Self { 83 | trace_id: self.trace_id, 84 | span_id: SpanId::random(&mut rand::thread_rng()), 85 | sampling_decision: self.sampling_decision, 86 | } 87 | } 88 | } 89 | 90 | impl TraceId { 91 | /// Returns a random trace ID that can be assumed to be globally unique if `rng` generates 92 | /// actually-random numbers. 93 | pub fn random(rng: &mut R) -> Self { 94 | TraceId(rng.gen::().get()) 95 | } 96 | 97 | /// Returns true iff the trace ID is 0. 98 | pub fn is_none(&self) -> bool { 99 | self.0 == 0 100 | } 101 | } 102 | 103 | impl SpanId { 104 | /// Returns a random span ID that can be assumed to be unique within a single trace. 105 | pub fn random(rng: &mut R) -> Self { 106 | SpanId(rng.gen::().get()) 107 | } 108 | 109 | /// Returns true iff the span ID is 0. 
110 | pub fn is_none(&self) -> bool { 111 | self.0 == 0 112 | } 113 | } 114 | 115 | impl From for u128 { 116 | fn from(trace_id: TraceId) -> Self { 117 | trace_id.0 118 | } 119 | } 120 | 121 | impl From for TraceId { 122 | fn from(trace_id: u128) -> Self { 123 | Self(trace_id) 124 | } 125 | } 126 | 127 | impl From for u64 { 128 | fn from(span_id: SpanId) -> Self { 129 | span_id.0 130 | } 131 | } 132 | 133 | impl From for SpanId { 134 | fn from(span_id: u64) -> Self { 135 | Self(span_id) 136 | } 137 | } 138 | 139 | impl From for TraceId { 140 | fn from(trace_id: opentelemetry::trace::TraceId) -> Self { 141 | Self::from(u128::from_be_bytes(trace_id.to_bytes())) 142 | } 143 | } 144 | 145 | impl From for opentelemetry::trace::TraceId { 146 | fn from(trace_id: TraceId) -> Self { 147 | Self::from_bytes(u128::from(trace_id).to_be_bytes()) 148 | } 149 | } 150 | 151 | impl From for SpanId { 152 | fn from(span_id: opentelemetry::trace::SpanId) -> Self { 153 | Self::from(u64::from_be_bytes(span_id.to_bytes())) 154 | } 155 | } 156 | 157 | impl From for opentelemetry::trace::SpanId { 158 | fn from(span_id: SpanId) -> Self { 159 | Self::from_bytes(u64::from(span_id).to_be_bytes()) 160 | } 161 | } 162 | 163 | impl TryFrom<&tracing::Span> for Context { 164 | type Error = NoActiveSpan; 165 | 166 | fn try_from(span: &tracing::Span) -> Result { 167 | let context = span.context(); 168 | if context.has_active_span() { 169 | Ok(Self::from(context.span())) 170 | } else { 171 | Err(NoActiveSpan) 172 | } 173 | } 174 | } 175 | 176 | impl From> for Context { 177 | fn from(span: opentelemetry::trace::SpanRef<'_>) -> Self { 178 | let otel_ctx = span.span_context(); 179 | Self { 180 | trace_id: TraceId::from(otel_ctx.trace_id()), 181 | span_id: SpanId::from(otel_ctx.span_id()), 182 | sampling_decision: SamplingDecision::from(otel_ctx), 183 | } 184 | } 185 | } 186 | 187 | impl From for opentelemetry::trace::TraceFlags { 188 | fn from(decision: SamplingDecision) -> Self { 189 | match decision { 
190 | SamplingDecision::Sampled => opentelemetry::trace::TraceFlags::SAMPLED, 191 | SamplingDecision::Unsampled => opentelemetry::trace::TraceFlags::default(), 192 | } 193 | } 194 | } 195 | 196 | impl From<&opentelemetry::trace::SpanContext> for SamplingDecision { 197 | fn from(context: &opentelemetry::trace::SpanContext) -> Self { 198 | if context.is_sampled() { 199 | SamplingDecision::Sampled 200 | } else { 201 | SamplingDecision::Unsampled 202 | } 203 | } 204 | } 205 | 206 | impl Default for SamplingDecision { 207 | fn default() -> Self { 208 | Self::Unsampled 209 | } 210 | } 211 | 212 | /// Returned when a [`Context`] cannot be constructed from a [`Span`](tracing::Span). 213 | #[derive(Debug)] 214 | pub struct NoActiveSpan; 215 | 216 | impl fmt::Display for TraceId { 217 | fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> { 218 | write!(f, "{:02x}", self.0)?; 219 | Ok(()) 220 | } 221 | } 222 | 223 | impl fmt::Debug for TraceId { 224 | fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> { 225 | write!(f, "{:02x}", self.0)?; 226 | Ok(()) 227 | } 228 | } 229 | 230 | impl fmt::Display for SpanId { 231 | fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> { 232 | write!(f, "{:02x}", self.0)?; 233 | Ok(()) 234 | } 235 | } 236 | 237 | impl fmt::Debug for SpanId { 238 | fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> { 239 | write!(f, "{:02x}", self.0)?; 240 | Ok(()) 241 | } 242 | } 243 | 244 | #[cfg(feature = "serde1")] 245 | mod u128_serde { 246 | pub fn serialize(u: &u128, serializer: S) -> Result 247 | where 248 | S: serde::Serializer, 249 | { 250 | serde::Serialize::serialize(&u.to_le_bytes(), serializer) 251 | } 252 | 253 | pub fn deserialize<'de, D>(deserializer: D) -> Result 254 | where 255 | D: serde::Deserializer<'de>, 256 | { 257 | Ok(u128::from_le_bytes(serde::Deserialize::deserialize( 258 | deserializer, 259 | )?)) 260 | } 261 | } 262 | -------------------------------------------------------------------------------- 
/tarpc/src/transport.rs:
--------------------------------------------------------------------------------
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

//! Provides a [`Transport`](sealed::Transport) trait as well as implementations.
//!
//! The rpc crate is transport- and protocol-agnostic. Any transport that impls
//! [`Transport`](sealed::Transport) can be plugged in, using whatever protocol it wants.

pub mod channel;

pub(crate) mod sealed {
    use futures::prelude::*;
    use std::error::Error;

    /// A bidirectional stream ([`Sink`] + [`Stream`]) of messages.
    pub trait Transport<SinkItem, Item>
    where
        Self: Stream<Item = Result<Item, <Self as Sink<SinkItem>>::Error>>,
        Self: Sink<SinkItem, Error = <Self as Transport<SinkItem, Item>>::TransportError>,
        <Self as Sink<SinkItem>>::Error: Error,
    {
        /// Associated type where clauses are not elaborated; this associated type allows users
        /// bounding types by Transport to avoid having to explicitly add `T::Error: Error` to their
        /// bounds.
        type TransportError: Error + Send + Sync + 'static;
    }

    // Blanket impl: anything that is a Stream of Results and a Sink with a matching error type
    // is a Transport.
    impl<T, SinkItem, Item, E> Transport<SinkItem, Item> for T
    where
        T: ?Sized,
        T: Stream<Item = Result<Item, E>>,
        T: Sink<SinkItem, Error = E>,
        T::Error: Error + Send + Sync + 'static,
    {
        type TransportError = E;
    }
}
--------------------------------------------------------------------------------
/tarpc/src/transport/channel.rs:
--------------------------------------------------------------------------------
// Copyright 2018 Google LLC
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

//! Transports backed by in-memory channels.
8 | 9 | use futures::{task::*, Sink, Stream}; 10 | use pin_project::pin_project; 11 | use std::{error::Error, pin::Pin}; 12 | use tokio::sync::mpsc; 13 | 14 | /// Errors that occur in the sending or receiving of messages over a channel. 15 | #[derive(thiserror::Error, Debug)] 16 | pub enum ChannelError { 17 | /// An error occurred readying to send into the channel. 18 | #[error("an error occurred readying to send into the channel")] 19 | Ready(#[source] Box), 20 | /// An error occurred sending into the channel. 21 | #[error("an error occurred sending into the channel")] 22 | Send(#[source] Box), 23 | /// An error occurred receiving from the channel. 24 | #[error("an error occurred receiving from the channel")] 25 | Receive(#[source] Box), 26 | } 27 | 28 | /// Returns two unbounded channel peers. Each [`Stream`] yields items sent through the other's 29 | /// [`Sink`]. 30 | pub fn unbounded() -> ( 31 | UnboundedChannel, 32 | UnboundedChannel, 33 | ) { 34 | let (tx1, rx2) = mpsc::unbounded_channel(); 35 | let (tx2, rx1) = mpsc::unbounded_channel(); 36 | ( 37 | UnboundedChannel { tx: tx1, rx: rx1 }, 38 | UnboundedChannel { tx: tx2, rx: rx2 }, 39 | ) 40 | } 41 | 42 | /// A bi-directional channel backed by an [`UnboundedSender`](mpsc::UnboundedSender) 43 | /// and [`UnboundedReceiver`](mpsc::UnboundedReceiver). 
44 | #[derive(Debug)] 45 | pub struct UnboundedChannel { 46 | rx: mpsc::UnboundedReceiver, 47 | tx: mpsc::UnboundedSender, 48 | } 49 | 50 | impl Stream for UnboundedChannel { 51 | type Item = Result; 52 | 53 | fn poll_next( 54 | mut self: Pin<&mut Self>, 55 | cx: &mut Context<'_>, 56 | ) -> Poll>> { 57 | self.rx 58 | .poll_recv(cx) 59 | .map(|option| option.map(Ok)) 60 | .map_err(ChannelError::Receive) 61 | } 62 | } 63 | 64 | const CLOSED_MESSAGE: &str = "the channel is closed and cannot accept new items for sending"; 65 | 66 | impl Sink for UnboundedChannel { 67 | type Error = ChannelError; 68 | 69 | fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { 70 | Poll::Ready(if self.tx.is_closed() { 71 | Err(ChannelError::Ready(CLOSED_MESSAGE.into())) 72 | } else { 73 | Ok(()) 74 | }) 75 | } 76 | 77 | fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> { 78 | self.tx 79 | .send(item) 80 | .map_err(|_| ChannelError::Send(CLOSED_MESSAGE.into())) 81 | } 82 | 83 | fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { 84 | // UnboundedSender requires no flushing. 85 | Poll::Ready(Ok(())) 86 | } 87 | 88 | fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { 89 | // UnboundedSender can't initiate closure. 90 | Poll::Ready(Ok(())) 91 | } 92 | } 93 | 94 | /// Returns two channel peers with buffer equal to `capacity`. Each [`Stream`] yields items sent 95 | /// through the other's [`Sink`]. 96 | pub fn bounded( 97 | capacity: usize, 98 | ) -> (Channel, Channel) { 99 | let (tx1, rx2) = futures::channel::mpsc::channel(capacity); 100 | let (tx2, rx1) = futures::channel::mpsc::channel(capacity); 101 | (Channel { tx: tx1, rx: rx1 }, Channel { tx: tx2, rx: rx2 }) 102 | } 103 | 104 | /// A bi-directional channel backed by a [`Sender`](futures::channel::mpsc::Sender) 105 | /// and [`Receiver`](futures::channel::mpsc::Receiver). 
#[pin_project]
#[derive(Debug)]
pub struct Channel<Item, SinkItem> {
    // futures' bounded mpsc endpoints are polled through Pin, hence the projections.
    #[pin]
    rx: futures::channel::mpsc::Receiver<Item>,
    #[pin]
    tx: futures::channel::mpsc::Sender<SinkItem>,
}

impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
    type Item = Result<Item, ChannelError>;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Item, ChannelError>>> {
        self.project()
            .rx
            .poll_next(cx)
            .map(|option| option.map(Ok))
            .map_err(ChannelError::Receive)
    }
}

impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
    type Error = ChannelError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project()
            .tx
            .poll_ready(cx)
            .map_err(|e| ChannelError::Ready(Box::new(e)))
    }

    fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
        self.project()
            .tx
            .start_send(item)
            .map_err(|e| ChannelError::Send(Box::new(e)))
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project()
            .tx
            .poll_flush(cx)
            .map_err(|e| ChannelError::Send(Box::new(e)))
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project()
            .tx
            .poll_close(cx)
            .map_err(|e| ChannelError::Send(Box::new(e)))
    }
}

#[cfg(all(test, feature = "tokio1"))]
mod tests {
    use crate::{
        client::{self, RpcError},
        context,
        server::{incoming::Incoming, serve, BaseChannel},
        transport::{
            self,
            channel::{Channel, UnboundedChannel},
        },
        ServerError,
    };
    use assert_matches::assert_matches;
    use futures::{prelude::*, stream};
    use std::io;
    use tracing::trace;

    #[test]
    fn ensure_is_transport() {
        // Compile-time check that both channel flavors satisfy the Transport blanket impl.
        fn is_transport<SinkItem, Item, T: crate::Transport<SinkItem, Item>>() {}
        is_transport::<(), (), UnboundedChannel<(), ()>>();
        is_transport::<(), (), Channel<(), ()>>();
    }

186 | #[tokio::test] 187 | async fn integration() -> anyhow::Result<()> { 188 | let _ = tracing_subscriber::fmt::try_init(); 189 | 190 | let (client_channel, server_channel) = transport::channel::unbounded(); 191 | tokio::spawn( 192 | stream::once(future::ready(server_channel)) 193 | .map(BaseChannel::with_defaults) 194 | .execute(serve(|_ctx, request: String| async move { 195 | request.parse::().map_err(|_| { 196 | ServerError::new( 197 | io::ErrorKind::InvalidInput, 198 | format!("{request:?} is not an int"), 199 | ) 200 | }) 201 | })) 202 | .for_each(|channel| async move { 203 | tokio::spawn(channel.for_each(|response| response)); 204 | }), 205 | ); 206 | 207 | let client = client::new(client::Config::default(), client_channel).spawn(); 208 | 209 | let response1 = client.call(context::current(), "123".into()).await; 210 | let response2 = client.call(context::current(), "abc".into()).await; 211 | 212 | trace!("response1: {:?}, response2: {:?}", response1, response2); 213 | 214 | assert_matches!(response1, Ok(123)); 215 | assert_matches!(response2, Err(RpcError::Server(e)) if e.kind == io::ErrorKind::InvalidInput); 216 | 217 | Ok(()) 218 | } 219 | } 220 | -------------------------------------------------------------------------------- /tarpc/src/util.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | use std::{ 8 | collections::HashMap, 9 | hash::{BuildHasher, Hash}, 10 | time::{Duration, Instant}, 11 | }; 12 | 13 | #[cfg(feature = "serde1")] 14 | #[cfg_attr(docsrs, doc(cfg(feature = "serde1")))] 15 | pub mod serde; 16 | 17 | /// Extension trait for [Instants](Instant) in the future, i.e. deadlines. 18 | pub trait TimeUntil { 19 | /// How much time from now until this time is reached. 
20 | fn time_until(&self) -> Duration; 21 | } 22 | 23 | impl TimeUntil for Instant { 24 | fn time_until(&self) -> Duration { 25 | self.duration_since(Instant::now()) 26 | } 27 | } 28 | 29 | /// Collection compaction; configurable `shrink_to_fit`. 30 | pub trait Compact { 31 | /// Compacts space if the ratio of length : capacity is less than `usage_ratio_threshold`. 32 | fn compact(&mut self, usage_ratio_threshold: f64); 33 | } 34 | 35 | impl Compact for HashMap 36 | where 37 | K: Eq + Hash, 38 | H: BuildHasher, 39 | { 40 | fn compact(&mut self, usage_ratio_threshold: f64) { 41 | let usage_ratio_threshold = usage_ratio_threshold.clamp(f64::MIN_POSITIVE, 1.); 42 | let cap = f64::max(1000., self.len() as f64 / usage_ratio_threshold); 43 | self.shrink_to(cap as usize); 44 | } 45 | } 46 | 47 | #[test] 48 | fn test_compact() { 49 | let mut map = HashMap::with_capacity(2048); 50 | assert_eq!(map.capacity(), 3584); 51 | 52 | // Make usage ratio 25% 53 | for i in 0..896 { 54 | map.insert(format!("k{i}"), "v"); 55 | } 56 | 57 | map.compact(-1.0); 58 | assert_eq!(map.capacity(), 3584); 59 | 60 | map.compact(0.25); 61 | assert_eq!(map.capacity(), 3584); 62 | 63 | map.compact(0.50); 64 | assert_eq!(map.capacity(), 1792); 65 | 66 | map.compact(1.0); 67 | assert_eq!(map.capacity(), 1792); 68 | 69 | map.compact(2.0); 70 | assert_eq!(map.capacity(), 1792); 71 | } 72 | -------------------------------------------------------------------------------- /tarpc/src/util/serde.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Use of this source code is governed by an MIT-style 4 | // license that can be found in the LICENSE file or at 5 | // https://opensource.org/licenses/MIT. 6 | 7 | use serde::{Deserialize, Deserializer, Serialize, Serializer}; 8 | use std::io; 9 | 10 | /// Serializes [`io::ErrorKind`] as a `u32`. 
11 | #[allow(clippy::trivially_copy_pass_by_ref)] // Exact fn signature required by serde derive 12 | pub fn serialize_io_error_kind_as_u32( 13 | kind: &io::ErrorKind, 14 | serializer: S, 15 | ) -> Result 16 | where 17 | S: Serializer, 18 | { 19 | use std::io::ErrorKind::*; 20 | match *kind { 21 | NotFound => 0, 22 | PermissionDenied => 1, 23 | ConnectionRefused => 2, 24 | ConnectionReset => 3, 25 | ConnectionAborted => 4, 26 | NotConnected => 5, 27 | AddrInUse => 6, 28 | AddrNotAvailable => 7, 29 | BrokenPipe => 8, 30 | AlreadyExists => 9, 31 | WouldBlock => 10, 32 | InvalidInput => 11, 33 | InvalidData => 12, 34 | TimedOut => 13, 35 | WriteZero => 14, 36 | Interrupted => 15, 37 | Other => 16, 38 | UnexpectedEof => 17, 39 | _ => 16, 40 | } 41 | .serialize(serializer) 42 | } 43 | 44 | /// Deserializes [`io::ErrorKind`] from a `u32`. 45 | pub fn deserialize_io_error_kind_from_u32<'de, D>( 46 | deserializer: D, 47 | ) -> Result 48 | where 49 | D: Deserializer<'de>, 50 | { 51 | use std::io::ErrorKind::*; 52 | Ok(match u32::deserialize(deserializer)? 
{ 53 | 0 => NotFound, 54 | 1 => PermissionDenied, 55 | 2 => ConnectionRefused, 56 | 3 => ConnectionReset, 57 | 4 => ConnectionAborted, 58 | 5 => NotConnected, 59 | 6 => AddrInUse, 60 | 7 => AddrNotAvailable, 61 | 8 => BrokenPipe, 62 | 9 => AlreadyExists, 63 | 10 => WouldBlock, 64 | 11 => InvalidInput, 65 | 12 => InvalidData, 66 | 13 => TimedOut, 67 | 14 => WriteZero, 68 | 15 => Interrupted, 69 | 16 => Other, 70 | 17 => UnexpectedEof, 71 | _ => Other, 72 | }) 73 | } 74 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail.rs: -------------------------------------------------------------------------------- 1 | #[test] 2 | fn ui() { 3 | let t = trybuild::TestCases::new(); 4 | t.compile_fail("tests/compile_fail/*.rs"); 5 | #[cfg(all(feature = "serde-transport", feature = "tcp"))] 6 | t.compile_fail("tests/compile_fail/serde_transport/*.rs"); 7 | #[cfg(not(feature = "serde1"))] 8 | t.compile_fail("tests/compile_fail/no_serde1/*.rs"); 9 | #[cfg(feature = "serde1")] 10 | t.compile_fail("tests/compile_fail/serde1/*.rs"); 11 | } 12 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/must_use_request_dispatch.rs: -------------------------------------------------------------------------------- 1 | use tarpc::client; 2 | 3 | #[tarpc::service] 4 | trait World { 5 | async fn hello(name: String) -> String; 6 | } 7 | 8 | fn main() { 9 | let (client_transport, _) = tarpc::transport::channel::unbounded(); 10 | 11 | #[deny(unused_must_use)] 12 | { 13 | WorldClient::new(client::Config::default(), client_transport).dispatch; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/must_use_request_dispatch.stderr: -------------------------------------------------------------------------------- 1 | error: unused `RequestDispatch` that must be used 2 | --> tests/compile_fail/must_use_request_dispatch.rs:13:9 3 
| | 4 | 13 | WorldClient::new(client::Config::default(), client_transport).dispatch; 5 | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 6 | | 7 | note: the lint level is defined here 8 | --> tests/compile_fail/must_use_request_dispatch.rs:11:12 9 | | 10 | 11 | #[deny(unused_must_use)] 11 | | ^^^^^^^^^^^^^^^ 12 | help: use `let _ = ...` to ignore the resulting value 13 | | 14 | 13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch; 15 | | +++++++ 16 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/no_serde1/no_explicit_serde_without_feature.rs: -------------------------------------------------------------------------------- 1 | #[tarpc::service(derive_serde = true)] 2 | trait Foo { 3 | async fn foo(); 4 | } 5 | 6 | fn main() { 7 | let x = FooRequest::Foo {}; 8 | x.serialize(); 9 | } 10 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/no_serde1/no_explicit_serde_without_feature.stderr: -------------------------------------------------------------------------------- 1 | error: To enable serde, first enable the `serde1` feature of tarpc 2 | --> tests/compile_fail/no_serde1/no_explicit_serde_without_feature.rs:1:18 3 | | 4 | 1 | #[tarpc::service(derive_serde = true)] 5 | | ^^^^^^^^^^^^ 6 | 7 | error[E0433]: failed to resolve: use of undeclared type `FooRequest` 8 | --> tests/compile_fail/no_serde1/no_explicit_serde_without_feature.rs:7:13 9 | | 10 | 7 | let x = FooRequest::Foo {}; 11 | | ^^^^^^^^^^ use of undeclared type `FooRequest` 12 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/no_serde1/no_implicit_serde_without_feature.rs: -------------------------------------------------------------------------------- 1 | #[tarpc::service] 2 | trait Foo { 3 | async fn foo(); 4 | } 5 | 6 | fn main() { 7 | let x = FooRequest::Foo {}; 8 | 
x.serialize(); 9 | } 10 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/no_serde1/no_implicit_serde_without_feature.stderr: -------------------------------------------------------------------------------- 1 | error[E0599]: no method named `serialize` found for enum `FooRequest` in the current scope 2 | --> tests/compile_fail/no_serde1/no_implicit_serde_without_feature.rs:8:7 3 | | 4 | 1 | #[tarpc::service] 5 | | ----------------- method `serialize` not found for this enum 6 | ... 7 | 8 | x.serialize(); 8 | | ^^^^^^^^^ method not found in `FooRequest` 9 | | 10 | = help: items from traits can only be used if the trait is implemented and in scope 11 | = note: the following trait defines an item `serialize`, perhaps you need to implement it: 12 | candidate #1: `serde::ser::Serialize` 13 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/serde1/deprecated.rs: -------------------------------------------------------------------------------- 1 | #![deny(warnings)] 2 | 3 | #[tarpc::service(derive_serde = true)] 4 | trait Foo { 5 | async fn foo(); 6 | } 7 | 8 | fn main() {} 9 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/serde1/deprecated.stderr: -------------------------------------------------------------------------------- 1 | error: use of deprecated constant `_::DEPRECATED_SYNTAX`: 2 | The form `tarpc::service(derive_serde = true)` is deprecated. 3 | Use `tarpc::service(derive = [Serialize, Deserialize])`. 
4 | --> tests/compile_fail/serde1/deprecated.rs:3:1 5 | | 6 | 3 | #[tarpc::service(derive_serde = true)] 7 | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 8 | | 9 | note: the lint level is defined here 10 | --> tests/compile_fail/serde1/deprecated.rs:1:9 11 | | 12 | 1 | #![deny(warnings)] 13 | | ^^^^^^^^ 14 | = note: `#[deny(deprecated)]` implied by `#[deny(warnings)]` 15 | = note: this error originates in the attribute macro `tarpc::service` (in Nightly builds, run with -Z macro-backtrace for more info) 16 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/serde1/incompatible.rs: -------------------------------------------------------------------------------- 1 | #![allow(deprecated)] 2 | #[tarpc::service(derive = [Clone], derive_serde = true)] 3 | trait Foo { 4 | async fn foo(); 5 | } 6 | 7 | fn main() {} 8 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/serde1/incompatible.stderr: -------------------------------------------------------------------------------- 1 | error: tarpc does not support `derive_serde` and `derive` at the same time 2 | --> tests/compile_fail/serde1/incompatible.rs:2:1 3 | | 4 | 2 | #[tarpc::service(derive = [Clone], derive_serde = true)] 5 | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 6 | | 7 | = note: this error originates in the attribute macro `tarpc::service` (in Nightly builds, run with -Z macro-backtrace for more info) 8 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/serde1/opt_out_serde.rs: -------------------------------------------------------------------------------- 1 | #![allow(deprecated)] 2 | 3 | use std::fmt::Formatter; 4 | 5 | #[tarpc::service(derive_serde = false)] 6 | trait Foo { 7 | async fn foo(); 8 | } 9 | 10 | fn foo(f: &mut Formatter) { 11 | let x = FooRequest::Foo {}; 12 | tarpc::serde::Serialize::serialize(&x, f); 13 | } 14 | 
15 | fn main() {} 16 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/serde1/opt_out_serde.stderr: -------------------------------------------------------------------------------- 1 | error[E0277]: the trait bound `FooRequest: Serialize` is not satisfied 2 | --> tests/compile_fail/serde1/opt_out_serde.rs:12:40 3 | | 4 | 12 | tarpc::serde::Serialize::serialize(&x, f); 5 | | ---------------------------------- ^^ the trait `Serialize` is not implemented for `FooRequest` 6 | | | 7 | | required by a bound introduced by this call 8 | | 9 | = note: for local types consider adding `#[derive(serde::Serialize)]` to your `FooRequest` type 10 | = note: for types from other crates check whether the crate offers a `serde` feature flag 11 | = help: the following other types implement trait `Serialize`: 12 | &'a T 13 | &'a mut T 14 | () 15 | (T,) 16 | (T0, T1) 17 | (T0, T1, T2) 18 | (T0, T1, T2, T3) 19 | (T0, T1, T2, T3, T4) 20 | and $N others 21 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.rs: -------------------------------------------------------------------------------- 1 | use tarpc::serde_transport; 2 | use tokio_serde::formats::Json; 3 | 4 | fn main() { 5 | #[deny(unused_must_use)] 6 | { 7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default); 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/serde_transport/must_use_tcp_connect.stderr: -------------------------------------------------------------------------------- 1 | error: unused `TcpConnect` that must be used 2 | --> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:7:9 3 | | 4 | 7 | serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default); 5 | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 6 | | 7 | 
note: the lint level is defined here 8 | --> tests/compile_fail/serde_transport/must_use_tcp_connect.rs:5:12 9 | | 10 | 5 | #[deny(unused_must_use)] 11 | | ^^^^^^^^^^^^^^^ 12 | help: use `let _ = ...` to ignore the resulting value 13 | | 14 | 7 | let _ = serde_transport::tcp::connect::<_, (), (), _, _>("0.0.0.0:0", Json::default); 15 | | +++++++ 16 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/tarpc_service_arg_pat.rs: -------------------------------------------------------------------------------- 1 | #[tarpc::service] 2 | trait World { 3 | async fn pat((a, b): (u8, u32)); 4 | } 5 | 6 | fn main() {} 7 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/tarpc_service_arg_pat.stderr: -------------------------------------------------------------------------------- 1 | error: patterns aren't allowed in RPC args 2 | --> $DIR/tarpc_service_arg_pat.rs:3:18 3 | | 4 | 3 | async fn pat((a, b): (u8, u32)); 5 | | ^^^^^^ 6 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/tarpc_service_derive_serde.rs: -------------------------------------------------------------------------------- 1 | #[tarpc::service(derive_serde = loop {})] 2 | trait World { 3 | async fn hello(); 4 | } 5 | 6 | fn main() {} 7 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/tarpc_service_derive_serde.stderr: -------------------------------------------------------------------------------- 1 | error: expected literal 2 | --> tests/compile_fail/tarpc_service_derive_serde.rs:1:33 3 | | 4 | 1 | #[tarpc::service(derive_serde = loop {})] 5 | | ^^^^ 6 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/tarpc_service_fn_new.rs: -------------------------------------------------------------------------------- 1 | 
#[tarpc::service] 2 | trait World { 3 | async fn new(); 4 | } 5 | 6 | fn main() {} 7 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/tarpc_service_fn_new.stderr: -------------------------------------------------------------------------------- 1 | error: method name conflicts with generated fn `WorldClient::new` 2 | --> $DIR/tarpc_service_fn_new.rs:3:14 3 | | 4 | 3 | async fn new(); 5 | | ^^^ 6 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/tarpc_service_fn_serve.rs: -------------------------------------------------------------------------------- 1 | #[tarpc::service] 2 | trait World { 3 | async fn serve(); 4 | } 5 | 6 | fn main() {} 7 | -------------------------------------------------------------------------------- /tarpc/tests/compile_fail/tarpc_service_fn_serve.stderr: -------------------------------------------------------------------------------- 1 | error: method name conflicts with generated fn `World::serve` 2 | --> $DIR/tarpc_service_fn_serve.rs:3:14 3 | | 4 | 3 | async fn serve(); 5 | | ^^^^^ 6 | -------------------------------------------------------------------------------- /tarpc/tests/dataservice.rs: -------------------------------------------------------------------------------- 1 | use futures::prelude::*; 2 | use tarpc::serde_transport; 3 | use tarpc::{ 4 | client, context, 5 | server::{incoming::Incoming, BaseChannel}, 6 | }; 7 | use tokio_serde::formats::Json; 8 | 9 | #[tarpc::derive_serde] 10 | #[derive(Debug, PartialEq, Eq)] 11 | pub enum TestData { 12 | Black, 13 | White, 14 | } 15 | 16 | #[tarpc::service] 17 | pub trait ColorProtocol { 18 | async fn get_opposite_color(color: TestData) -> TestData; 19 | } 20 | 21 | #[derive(Clone)] 22 | struct ColorServer; 23 | 24 | impl ColorProtocol for ColorServer { 25 | async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { 26 | match color { 27 | 
TestData::White => TestData::Black, 28 | TestData::Black => TestData::White, 29 | } 30 | } 31 | } 32 | 33 | #[cfg(test)] 34 | async fn spawn(fut: impl Future + Send + 'static) { 35 | tokio::spawn(fut); 36 | } 37 | 38 | #[tokio::test] 39 | async fn test_call() -> anyhow::Result<()> { 40 | // Bind port 0 so the OS picks a free port; a hard-coded port (previously 56797) can collide with parallel test runs. The client below connects via `local_addr()`, so the assigned port is picked up automatically. let transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default).await?; 41 | let addr = transport.local_addr(); 42 | tokio::spawn( 43 | transport 44 | .take(1) 45 | .filter_map(|r| async { r.ok() }) 46 | .map(BaseChannel::with_defaults) 47 | .execute(ColorServer.serve()) 48 | .map(|channel| channel.for_each(spawn)) 49 | .for_each(spawn), 50 | ); 51 | 52 | let transport = serde_transport::tcp::connect(addr, Json::default).await?; 53 | let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); 54 | 55 | let color = client 56 | .get_opposite_color(context::current(), TestData::White) 57 | .await?; 58 | assert_eq!(color, TestData::Black); 59 | 60 | Ok(()) 61 | } 62 | -------------------------------------------------------------------------------- /tarpc/tests/proc_macro_hygene.rs: -------------------------------------------------------------------------------- 1 | #![no_implicit_prelude] 2 | extern crate tarpc as some_random_other_name; 3 | 4 | #[cfg(feature = "serde1")] 5 | mod serde1_feature { 6 | #[::tarpc::derive_serde] 7 | #[derive(Debug, PartialEq, Eq)] 8 | pub enum TestData { 9 | Black, 10 | White, 11 | } 12 | } 13 | 14 | #[::tarpc::service] 15 | pub trait ColorProtocol { 16 | async fn get_opposite_color(color: u8) -> u8; 17 | } 18 | -------------------------------------------------------------------------------- /tarpc/tests/service_functional.rs: -------------------------------------------------------------------------------- 1 | use assert_matches::assert_matches; 2 | use futures::{ 3 | future::{join_all, ready}, 4 | prelude::*, 5 | }; 6 | use std::time::{Duration, Instant}; 7 | use tarpc::{ 8 | client::{self}, 9 | context, 10 |
server::{incoming::Incoming, BaseChannel, Channel}, 11 | transport::channel, 12 | }; 13 | use tokio::join; 14 | 15 | #[tarpc_plugins::service] 16 | trait Service { 17 | async fn add(x: i32, y: i32) -> i32; 18 | async fn hey(name: String) -> String; 19 | } 20 | 21 | #[derive(Clone)] 22 | struct Server; 23 | 24 | impl Service for Server { 25 | async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { 26 | x + y 27 | } 28 | 29 | async fn hey(self, _: context::Context, name: String) -> String { 30 | format!("Hey, {name}.") 31 | } 32 | } 33 | 34 | #[tokio::test] 35 | async fn sequential() { 36 | let (tx, rx) = tarpc::transport::channel::unbounded(); 37 | let client = client::new(client::Config::default(), tx).spawn(); 38 | let channel = BaseChannel::with_defaults(rx); 39 | tokio::spawn( 40 | channel 41 | .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) })) 42 | .for_each(|response| response), 43 | ); 44 | assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); 45 | } 46 | 47 | #[tokio::test] 48 | async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { 49 | #[tarpc_plugins::service] 50 | trait Loop { 51 | async fn r#loop(); 52 | } 53 | 54 | #[derive(Clone)] 55 | struct LoopServer; 56 | 57 | impl Loop for LoopServer { 58 | async fn r#loop(self, _: context::Context) { 59 | loop { 60 | futures::pending!(); 61 | } 62 | } 63 | } 64 | 65 | let _ = tracing_subscriber::fmt::try_init(); 66 | 67 | let (tx, rx) = channel::unbounded(); 68 | 69 | // Set up a client that initiates a long-lived request. 70 | // The request will complete in error when the server drops the connection. 
71 | tokio::spawn(async move { 72 | let client = LoopClient::new(client::Config::default(), tx).spawn(); 73 | 74 | let mut ctx = context::current(); 75 | ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); 76 | let _ = client.r#loop(ctx).await; 77 | }); 78 | 79 | let mut requests = BaseChannel::with_defaults(rx).requests(); 80 | // Reading a request should trigger the request being registered with BaseChannel. 81 | let first_request = requests.next().await.unwrap()?; 82 | // Dropping the channel should trigger cleanup of outstanding requests. 83 | drop(requests); 84 | // In-flight requests should be aborted by channel cleanup. 85 | // The first and only request sent by the client is `loop`, which is an infinite loop 86 | // on the server side, so if cleanup was not triggered, this line should hang indefinitely. 87 | first_request.execute(LoopServer.serve()).await; 88 | 89 | Ok(()) 90 | } 91 | 92 | #[cfg(all(feature = "serde-transport", feature = "tcp"))] 93 | #[tokio::test] 94 | async fn serde_tcp() -> anyhow::Result<()> { 95 | use tarpc::serde_transport; 96 | use tokio_serde::formats::Json; 97 | 98 | let _ = tracing_subscriber::fmt::try_init(); 99 | 100 | // Bind port 0 so the OS picks a free port; a hard-coded port (previously 56789) can collide with parallel test runs. The client below connects via `local_addr()`, so the assigned port is picked up automatically. let transport = tarpc::serde_transport::tcp::listen("localhost:0", Json::default).await?; 101 | let addr = transport.local_addr(); 102 | tokio::spawn( 103 | transport 104 | .take(1) 105 | .filter_map(|r| async { r.ok() }) 106 | .map(BaseChannel::with_defaults) 107 | .execute(Server.serve()) 108 | .map(|channel| channel.for_each(spawn)) 109 | .for_each(spawn), 110 | ); 111 | 112 | let transport = serde_transport::tcp::connect(addr, Json::default).await?; 113 | let client = ServiceClient::new(client::Config::default(), transport).spawn(); 114 | 115 | assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); 116 | assert_matches!( 117 | client.hey(context::current(), "Tim".to_string()).await, 118 | Ok(ref s) if s == "Hey, Tim."
119 | ); 120 | 121 | Ok(()) 122 | } 123 | 124 | #[cfg(all(feature = "serde-transport", feature = "unix", unix))] 125 | #[tokio::test] 126 | async fn serde_uds() -> anyhow::Result<()> { 127 | use tarpc::serde_transport; 128 | use tokio_serde::formats::Json; 129 | 130 | let _ = tracing_subscriber::fmt::try_init(); 131 | 132 | let sock = tarpc::serde_transport::unix::TempPathBuf::with_random("uds"); 133 | let transport = tarpc::serde_transport::unix::listen(&sock, Json::default).await?; 134 | tokio::spawn( 135 | transport 136 | .take(1) 137 | .filter_map(|r| async { r.ok() }) 138 | .map(BaseChannel::with_defaults) 139 | .execute(Server.serve()) 140 | .map(|channel| channel.for_each(spawn)) 141 | .for_each(spawn), 142 | ); 143 | 144 | let transport = serde_transport::unix::connect(&sock, Json::default).await?; 145 | let client = ServiceClient::new(client::Config::default(), transport).spawn(); 146 | 147 | // Save results using socket so we can clean the socket even if our test assertions fail 148 | let res1 = client.add(context::current(), 1, 2).await; 149 | let res2 = client.hey(context::current(), "Tim".to_string()).await; 150 | 151 | assert_matches!(res1, Ok(3)); 152 | assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); 153 | 154 | Ok(()) 155 | } 156 | 157 | #[tokio::test] 158 | async fn concurrent() -> anyhow::Result<()> { 159 | let _ = tracing_subscriber::fmt::try_init(); 160 | 161 | let (tx, rx) = channel::unbounded(); 162 | tokio::spawn( 163 | stream::once(ready(rx)) 164 | .map(BaseChannel::with_defaults) 165 | .execute(Server.serve()) 166 | .map(|channel| channel.for_each(spawn)) 167 | .for_each(spawn), 168 | ); 169 | 170 | let client = ServiceClient::new(client::Config::default(), tx).spawn(); 171 | 172 | let req1 = client.add(context::current(), 1, 2); 173 | let req2 = client.add(context::current(), 3, 4); 174 | let req3 = client.hey(context::current(), "Tim".to_string()); 175 | 176 | assert_matches!(req1.await, Ok(3)); 177 | assert_matches!(req2.await, 
Ok(7)); 178 | assert_matches!(req3.await, Ok(ref s) if s == "Hey, Tim."); 179 | 180 | Ok(()) 181 | } 182 | 183 | #[tokio::test] 184 | async fn concurrent_join() -> anyhow::Result<()> { 185 | let _ = tracing_subscriber::fmt::try_init(); 186 | 187 | let (tx, rx) = channel::unbounded(); 188 | tokio::spawn( 189 | stream::once(ready(rx)) 190 | .map(BaseChannel::with_defaults) 191 | .execute(Server.serve()) 192 | .map(|channel| channel.for_each(spawn)) 193 | .for_each(spawn), 194 | ); 195 | 196 | let client = ServiceClient::new(client::Config::default(), tx).spawn(); 197 | 198 | let req1 = client.add(context::current(), 1, 2); 199 | let req2 = client.add(context::current(), 3, 4); 200 | let req3 = client.hey(context::current(), "Tim".to_string()); 201 | 202 | let (resp1, resp2, resp3) = join!(req1, req2, req3); 203 | assert_matches!(resp1, Ok(3)); 204 | assert_matches!(resp2, Ok(7)); 205 | assert_matches!(resp3, Ok(ref s) if s == "Hey, Tim."); 206 | 207 | Ok(()) 208 | } 209 | 210 | #[cfg(test)] 211 | async fn spawn(fut: impl Future + Send + 'static) { 212 | tokio::spawn(fut); 213 | } 214 | 215 | #[tokio::test] 216 | async fn concurrent_join_all() -> anyhow::Result<()> { 217 | let _ = tracing_subscriber::fmt::try_init(); 218 | 219 | let (tx, rx) = channel::unbounded(); 220 | tokio::spawn( 221 | BaseChannel::with_defaults(rx) 222 | .execute(Server.serve()) 223 | .for_each(spawn), 224 | ); 225 | 226 | let client = ServiceClient::new(client::Config::default(), tx).spawn(); 227 | 228 | let req1 = client.add(context::current(), 1, 2); 229 | let req2 = client.add(context::current(), 3, 4); 230 | 231 | let responses = join_all(vec![req1, req2]).await; 232 | assert_matches!(responses[0], Ok(3)); 233 | assert_matches!(responses[1], Ok(7)); 234 | 235 | Ok(()) 236 | } 237 | 238 | #[tokio::test] 239 | async fn counter() -> anyhow::Result<()> { 240 | #[tarpc::service] 241 | trait Counter { 242 | async fn count() -> u32; 243 | } 244 | 245 | struct CountService(u32); 246 | 247 | impl 
Counter for &mut CountService { 248 | async fn count(self, _: context::Context) -> u32 { 249 | self.0 += 1; 250 | self.0 251 | } 252 | } 253 | 254 | let (tx, rx) = channel::unbounded(); 255 | tokio::spawn(async { 256 | let mut requests = BaseChannel::with_defaults(rx).requests(); 257 | let mut counter = CountService(0); 258 | 259 | while let Some(Ok(request)) = requests.next().await { 260 | request.execute(counter.serve()).await; 261 | } 262 | }); 263 | 264 | let client = CounterClient::new(client::Config::default(), tx).spawn(); 265 | assert_matches!(client.count(context::current()).await, Ok(1)); 266 | assert_matches!(client.count(context::current()).await, Ok(2)); 267 | 268 | Ok(()) 269 | } 270 | --------------------------------------------------------------------------------