├── .gitignore ├── tests ├── common │ ├── mod.rs │ ├── server.rs │ └── proto.rs ├── test.proto └── lib.rs ├── Cargo.toml ├── src ├── status.rs ├── lib.rs ├── nest.rs └── rest_grpc.rs └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock -------------------------------------------------------------------------------- /tests/common/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod proto; 2 | pub mod server; 3 | -------------------------------------------------------------------------------- /tests/test.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | package proto; 3 | 4 | message Test1Request { } 5 | message Test1Reply { } 6 | 7 | message Test2Request { } 8 | message Test2Reply { } 9 | 10 | service Test1 { 11 | rpc test1(Test1Request) returns (Test1Reply); 12 | } 13 | 14 | service Test2 { 15 | rpc test2(Test2Request) returns (Test2Reply); 16 | } -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | 2 | [package] 3 | edition = "2024" 4 | name = "axum_tonic" 5 | version = "0.4.1" 6 | license = "MIT OR Apache-2.0" 7 | description = "Use Tonic with Axum" 8 | repository = "https://github.com/jvdwrf/axum-tonic" 9 | keywords = ["axum", "tonic", "interop", "grpc", "web"] 10 | categories = ["web-programming"] 11 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 12 | 13 | [dependencies] 14 | axum = "0.8" 15 | tonic = "0.13" 16 | hyper = "1" 17 | futures = "0.3" 18 | tower = { version = "0.5", features = ["make"] } 19 | http-body = "1" 20 | 21 | [build-dependencies] 22 | tonic-build = "0.13" 23 | 24 | [dev-dependencies] 25 | tokio = { version = "1", features = ["full"] } 26 | prost = "0.13" 27 | tower-http = { 
version = "0.6", features = [ 28 | "compression-gzip", 29 | "cors", 30 | "compression-br", 31 | "compression-deflate", 32 | "trace", 33 | ] } 34 | tracing-subscriber = "0.3" 35 | -------------------------------------------------------------------------------- /src/status.rs: -------------------------------------------------------------------------------- 1 | use axum::response::{IntoResponse, Response}; 2 | use std::ops::{Deref, DerefMut}; 3 | 4 | /// A simple wrapper around a `tonic::Status` usable as the error type of axum middleware; its `IntoResponse` impl turns the status into an HTTP response via `tonic::Status::into_http`. 5 | /// 6 | /// ## Example 7 | /// ``` 8 | /// use axum::{middleware::{Next, from_fn}, response::Response, Router, extract::Request}; 9 | /// use axum_tonic::GrpcStatus; 10 | /// 11 | /// async fn tonic_middleware( 12 | /// req: Request, 13 | /// next: Next 14 | /// ) -> Result { 15 | /// if is_auth(&req) { 16 | /// Ok(next.run(req).await) 17 | /// } else { 18 | /// Err( 19 | /// tonic::Status::permission_denied("Not authenticated").into() 20 | /// ) 21 | /// } 22 | /// } 23 | /// 24 | /// fn is_auth(req: &Request) -> bool { 25 | /// true // or other logic 26 | /// } 27 | /// 28 | /// let router: Router<()> = Router::new() 29 | /// .layer(from_fn(tonic_middleware)); 30 | /// ``` 31 | #[derive(Debug)] 32 | pub struct GrpcStatus(pub tonic::Status); 33 | 34 | impl From for GrpcStatus { 35 | fn from(s: tonic::Status) -> Self { 36 | Self(s) 37 | } 38 | } 39 | 40 | impl Deref for GrpcStatus { 41 | type Target = tonic::Status; 42 | 43 | fn deref(&self) -> &Self::Target { 44 | &self.0 45 | } 46 | } 47 | 48 | impl DerefMut for GrpcStatus { 49 | fn deref_mut(&mut self) -> &mut Self::Target { 50 | &mut self.0 51 | } 52 | } 53 | 54 | impl IntoResponse for GrpcStatus { 55 | fn into_response(self) -> Response { 56 | self.0.into_http::().into_response() 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /tests/common/server.rs: -------------------------------------------------------------------------------- 1 | use
std::sync::Mutex; 2 | 3 | use crate::common::proto::test1_server::*; 4 | use crate::common::proto::test2_server::*; 5 | use crate::common::proto::*; 6 | use tonic::Response; 7 | use tonic::async_trait; 8 | 9 | pub struct Test1Service { 10 | pub state: Mutex, // incremented by 5 on every `test1` call 11 | pub str: String, 12 | } 13 | 14 | #[async_trait] 15 | impl Test1 for Test1Service { 16 | async fn test1( 17 | &self, 18 | _request: tonic::Request, 19 | ) -> Result, tonic::Status> { 20 | *self.state.lock().unwrap() += 5; 21 | 22 | println!("{}", self.state.lock().unwrap().clone()); // NOTE(review): re-locks after the increment above, so a concurrent call may change the value in between 23 | 24 | Ok(Response::new(Test1Reply {})) 25 | } 26 | } 27 | 28 | pub struct Test2Service; 29 | #[async_trait] 30 | impl Test2 for Test2Service { 31 | async fn test2( 32 | &self, 33 | _request: tonic::Request, 34 | ) -> Result, tonic::Status> { 35 | Ok(Response::new(Test2Reply {})) 36 | } 37 | } 38 | 39 | pub struct Test1ServiceWithConnectInfo { 40 | pub state: Mutex, // incremented by 5 on every `test1` call 41 | pub str: String, 42 | } 43 | 44 | #[async_trait] 45 | impl Test1 for Test1ServiceWithConnectInfo { 46 | async fn test1( 47 | &self, 48 | request: tonic::Request, 49 | ) -> Result, tonic::Status> { 50 | *self.state.lock().unwrap() += 5; 51 | 52 | println!("{}", self.state.lock().unwrap().clone()); // NOTE(review): re-locks after the increment above, so a concurrent call may change the value in between 53 | if request.remote_addr().is_some() { // succeeds only if the server forwarded connect info into the request 54 | Ok(Response::new(Test1Reply {})) 55 | } else { 56 | Err(tonic::Status::internal("connect info error")) 57 | } 58 | } 59 | } 60 | 61 | pub struct Test2ServiceWithConnectInfo; 62 | #[async_trait] 63 | impl Test2 for Test2ServiceWithConnectInfo { 64 | async fn test2( 65 | &self, 66 | request: tonic::Request, 67 | ) -> Result, tonic::Status> { 68 | if request.remote_addr().is_some() { // succeeds only if the server forwarded connect info into the request 69 | Ok(Response::new(Test2Reply {})) 70 | } else { 71 | Err(tonic::Status::internal("connect info error")) 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! ```ignore 2 | //!
/// A middleware that does nothing, but just passes on the request. 3 | //! async fn do_nothing_middleware(req: Request, next: Next) -> Result { 4 | //! Ok(next.run(req).await) 5 | //! } 6 | //! 7 | //! /// A middleware that cancels the request with a grpc status-code 8 | //! async fn cancel_request_middleware(_req: Request, _next: Next) -> Result { 9 | //! Err(tonic::Status::cancelled("Canceled").into()) 10 | //! } 11 | //! 12 | //! #[tokio::main] 13 | //! async fn main() { 14 | //! 15 | //! // Spawn the Server 16 | //! tokio::task::spawn(async move { 17 | //! // The first grpc-service has middleware that accepts the request. 18 | //! let grpc_router1 = Router::new() 19 | //! .nest_tonic(Test1Server::new(Test1Service)) 20 | //! .layer(from_fn(do_nothing_middleware)); 21 | //! 22 | //! // The second grpc-service instead cancels the request 23 | //! let grpc_router2 = Router::new() 24 | //! .nest_tonic(Test2Server::new(Test2Service)) 25 | //! .layer(from_fn(cancel_request_middleware)); 26 | //! 27 | //! // Merge both routers into one. 28 | //! let grpc_router = grpc_router1.merge(grpc_router2); 29 | //! 30 | //! // This is the normal rest-router, to which all normal requests are routed 31 | //! let rest_router = Router::new() 32 | //! .merge(Router::new().route("/123", get(|| async move {}))) 33 | //! .route("/", get(|| async move {})); 34 | //! 35 | //! // Combine both services into one 36 | //! let service = RestGrpcService::new(rest_router, grpc_router).into_make_service(); 37 | //! 38 | //! // And serve at 127.0.0.1:8080 39 | //! let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap(); 40 | //! axum::serve(listener, service) 41 | //! .await 42 | //! .unwrap(); 43 | //! }); 44 | //! 45 | //! tokio::time::sleep(Duration::from_millis(100)).await; 46 | //! 47 | //! // Connect to the server with a grpc-client 48 | //! let channel = Channel::from_static("http://127.0.0.1:8080") 49 | //! .connect() 50 | //! .await 51 | //! .unwrap(); 52 | //!
let mut client1 = Test1Client::new(channel.clone()); 54 | //! let mut client2 = Test2Client::new(channel); 55 | //! 56 | //! // The first request will succeed 57 | //! client1.test1(Test1Request {}).await.unwrap(); 58 | //! 59 | //! // While the second one gives a grpc Status::Canceled code. 60 | //! assert_eq!( 61 | //! client2.test2(Test2Request {}).await.unwrap_err().code(), 62 | //! tonic::Code::Cancelled, 63 | //! ); 64 | //! } 65 | //! ``` 66 | 67 | mod nest; 68 | mod rest_grpc; 69 | mod status; 70 | 71 | pub use nest::NestTonic; 72 | pub use rest_grpc::RestGrpcService; 73 | pub use status::GrpcStatus; 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # axum_tonic 2 | 3 | [![Crates.io](https://img.shields.io/crates/v/axum_tonic)](https://crates.io/crates/axum_tonic) 4 | [![Documentation](https://docs.rs/axum_tonic/badge.svg)](https://docs.rs/axum_tonic) 5 | 6 | A tiny crate to use Tonic with Axum. 7 | 8 | This crate makes it simple to use different kinds of middleware with different tonic-services. 9 | 10 | The recommended way to use this is to create two separate root-routers, one for grpc and one for rest. Then both can be combined together at the root, and turned into a make service. 11 | 12 | See the docs of Axum or Tonic for more information about the respective frameworks. 13 | 14 | ## Example 15 | 16 | ```rust 17 | 18 | /// A middleware that does nothing, but just passes on the request. 
19 | async fn do_nothing_middleware(req: Request, next: Next) -> Result { 20 | Ok(next.run(req).await) 21 | } 22 | 23 | /// A middleware that cancels the request with a grpc status-code 24 | async fn cancel_request_middleware(_req: Request, _next: Next) -> Result { 25 | Err(tonic::Status::cancelled("Canceled").into()) 26 | } 27 | 28 | #[tokio::main] 29 | async fn main() { 30 | 31 | // Spawn the Server 32 | tokio::task::spawn(async move { 33 | // The first grpc-service has middleware that accepts the request. 34 | let grpc_router1 = Router::new() 35 | .nest_tonic(Test1Server::new(Test1Service)) 36 | .layer(from_fn(do_nothing_middleware)); 37 | 38 | // The second grpc-service instead cancels the request 39 | let grpc_router2 = Router::new() 40 | .nest_tonic(Test2Server::new(Test2Service)) 41 | .layer(from_fn(cancel_request_middleware)); 42 | 43 | // Merge both routers into one. 44 | let grpc_router = grpc_router1.merge(grpc_router2); 45 | 46 | // This is the normal rest-router, to which all normal requests are routed 47 | let rest_router = Router::new() 48 | .merge(Router::new().route("/123", get(|| async move {}))) 49 | .route("/", get(|| async move {})); 50 | 51 | // Combine both services into one 52 | let service = RestGrpcService::new(rest_router, grpc_router).into_make_service(); 53 | 54 | // And serve at 127.0.0.1:8080 55 | let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap(); 56 | axum::serve(listener, service) 57 | .await 58 | .unwrap(); 59 | }); 60 | 61 | tokio::time::sleep(Duration::from_millis(100)).await; 62 | 63 | // Connect to the server with a grpc-client 64 | let channel = Channel::from_static("http://127.0.0.1:8080") 65 | .connect() 66 | .await 67 | .unwrap(); 68 | 69 | let mut client1 = Test1Client::new(channel.clone()); 70 | let mut client2 = Test2Client::new(channel); 71 | 72 | // The first request will succeed 73 | client1.test1(Test1Request {}).await.unwrap(); 74 | 75 | // While the second one fails with the grpc status code `Cancelled`.
76 | assert_eq!( 77 | client2.test2(Test2Request {}).await.unwrap_err().code(), 78 | tonic::Code::Cancelled, 79 | ); 80 | } 81 | ``` -------------------------------------------------------------------------------- /src/nest.rs: -------------------------------------------------------------------------------- 1 | use std::convert::Infallible; 2 | 3 | use axum::{Router, response::IntoResponse, routing::any_service}; 4 | use futures::{Future, FutureExt}; 5 | use hyper::Request; 6 | use tonic::server::NamedService; 7 | use tower::Service; 8 | 9 | /// Extension trait that nests a tonic `NamedService` under the path prefix `/{NamedService::NAME}/` of a router. 10 | pub trait NestTonic: Sized { 11 | /// Route every request under `/{S::NAME}/` (relative to this router) to the tonic-service `svc`, wrapped so its response is converted into an axum response. 12 | fn nest_tonic(self, svc: S) -> Self 13 | where 14 | S: Service< 15 | hyper::Request, 16 | Error = Infallible, 17 | Response = hyper::Response, 18 | > 19 | + Clone 20 | + Send 21 | + Sync 22 | + 'static 23 | + NamedService, 24 | S::Future: Send + 'static + Unpin, 25 | B: Send + http_body::Body + 'static, 26 | B::Error: Into>; 27 | } 28 | 29 | impl NestTonic for Router { 30 | fn nest_tonic(self, svc: S) -> Self 31 | where 32 | S: Service< 33 | hyper::Request, 34 | Error = Infallible, 35 | Response = hyper::Response, 36 | > 37 | + Clone 38 | + Send 39 | + Sync 40 | + 'static 41 | + NamedService, 42 | S::Future: Send + 'static + Unpin, 43 | B: Send + http_body::Body + 'static, 44 | B::Error: Into>, 45 | { 46 | // Nest it at /S::NAME, and wrap the service in an AxumTonicService 47 | self.route( 48 | &format!("/{}/{{*grpc_service}}", S::NAME), 49 | any_service(AxumTonicService { svc }), 50 | ) 51 | } 52 | } 53 | 54 | //------------------------------------------------------------------------------------------------ 55 | // Service 56 | //------------------------------------------------------------------------------------------------ 57 | 58 | /// The service that converts a tonic service into an axum-compatible one.
59 | #[derive(Clone, Debug)] 60 | struct AxumTonicService { 61 | svc: S, 62 | } 63 | 64 | impl Service> for AxumTonicService 65 | where 66 | S: Service, Error = Infallible, Response = hyper::Response>, 67 | S::Future: Unpin, 68 | TBody: Send + http_body::Body + 'static, 69 | TBody::Error: Into>, 70 | { 71 | type Response = axum::response::Response; 72 | type Error = Infallible; 73 | type Future = AxumTonicServiceFut; 74 | 75 | fn poll_ready( 76 | &mut self, 77 | cx: &mut std::task::Context<'_>, 78 | ) -> std::task::Poll> { 79 | self.svc.poll_ready(cx) 80 | } 81 | 82 | fn call(&mut self, req: Request) -> Self::Future { 83 | AxumTonicServiceFut { 84 | fut: self.svc.call(req), 85 | } 86 | } 87 | } 88 | 89 | //------------------------------------------------------------------------------------------------ 90 | // Future 91 | //------------------------------------------------------------------------------------------------ 92 | 93 | /// The future returned by `AxumTonicService::call`: it polls the inner tonic future and converts its response into an axum `Response`. 94 | struct AxumTonicServiceFut { 95 | fut: F, 96 | } 97 | 98 | impl Future for AxumTonicServiceFut 99 | where 100 | F: Future, Infallible>> + Unpin, 101 | B: Send + http_body::Body + 'static, 102 | B::Error: Into>, 103 | { 104 | type Output = Result; 105 | 106 | fn poll( 107 | mut self: std::pin::Pin<&mut Self>, 108 | cx: &mut std::task::Context<'_>, 109 | ) -> std::task::Poll { 110 | // The inner error type is `Infallible`, so only the `Ok` response needs converting into an axum response 111 | self.fut 112 | .poll_unpin(cx) 113 | .map_ok(|response| response.into_response()) 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /tests/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod common; 2 | 3 | use std::{net::SocketAddr, sync::Mutex, time::Duration}; 4 | 5 | use axum::{ 6 | Router, 7 | extract::Request, 8 | middleware::{Next, from_fn}, 9 | response::Response, 10 | routing::get, 11 | }; 12 | use axum_tonic::{GrpcStatus,
NestTonic, RestGrpcService}; 13 | use common::{ 14 | proto::{ 15 | Test1Request, Test2Request, test1_client::Test1Client, test1_server::Test1Server, 16 | test2_client::Test2Client, test2_server::Test2Server, 17 | }, 18 | server::{Test1Service, Test1ServiceWithConnectInfo, Test2Service, Test2ServiceWithConnectInfo}, 19 | }; 20 | use tokio::net::TcpListener; 21 | use tonic::transport::Channel; 22 | 23 | 24 | async fn do_nothing(req: Request, next: Next) -> Result { // passes every request through unchanged 25 | Ok(next.run(req).await) 26 | } 27 | 28 | async fn cancel_request(_req: Request, _next: Next) -> Result { // rejects every request with a Cancelled grpc status 29 | Err(tonic::Status::cancelled("Canceled").into()) 30 | } 31 | 32 | #[tokio::test] 33 | async fn main() { 34 | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 35 | let port = listener.local_addr().unwrap().port(); 36 | let address = Box::leak(Box::new(format!("http://127.0.0.1:{port}"))); // leaked because `Channel::from_static` below needs a `&'static str` 37 | 38 | tokio::task::spawn(async move { 39 | let grpc_router1 = Router::new() 40 | .nest_tonic(Test1Server::new(Test1Service { 41 | state: Mutex::new(10), 42 | str: String::new(), 43 | })) 44 | .layer(from_fn(do_nothing)); 45 | 46 | let grpc_router2 = Router::new() 47 | .nest_tonic(Test2Server::new(Test2Service)) 48 | .layer(from_fn(cancel_request)); 49 | 50 | let grpc_router = grpc_router1.merge(grpc_router2); 51 | 52 | let rest_router = Router::new().merge(Router::new().route("/123", get(|| async move {}))); 53 | 54 | let service = RestGrpcService::new(rest_router, grpc_router).into_make_service(); 55 | 56 | axum::serve(listener, service).await.unwrap(); 57 | }); 58 | 59 | tokio::time::sleep(Duration::from_millis(100)).await; 60 | 61 | let channel = Channel::from_static(address).connect().await.unwrap(); 62 | 63 | let mut client1 = Test1Client::new(channel.clone()); 64 | client1.test1(Test1Request {}).await.unwrap(); 65 | client1.test1(Test1Request {}).await.unwrap(); 66 | client1.test1(Test1Request {}).await.unwrap(); 67 | client1.test1(Test1Request {}).await.unwrap(); 68 |
client1.test1(Test1Request {}).await.unwrap(); 69 | 70 | let channel = Channel::from_static(address).connect().await.unwrap(); // second connection; only client2 below uses it, client1 keeps the first channel 71 | 72 | client1.test1(Test1Request {}).await.unwrap(); 73 | client1.test1(Test1Request {}).await.unwrap(); 74 | client1.test1(Test1Request {}).await.unwrap(); 75 | client1.test1(Test1Request {}).await.unwrap(); 76 | client1.test1(Test1Request {}).await.unwrap(); 77 | 78 | let mut client2 = Test2Client::new(channel); 79 | assert_eq!( 80 | client2.test2(Test2Request {}).await.unwrap_err().code(), 81 | tonic::Code::Cancelled, 82 | ); 83 | } 84 | 85 | #[tokio::test] 86 | async fn main_connect_info() { 87 | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 88 | let port = listener.local_addr().unwrap().port(); 89 | let address = Box::leak(Box::new(format!("http://127.0.0.1:{port}"))); // leaked because `Channel::from_static` below needs a `&'static str` 90 | 91 | tokio::task::spawn(async move { 92 | let grpc_router1 = Router::new() 93 | .nest_tonic(Test1Server::new(Test1ServiceWithConnectInfo { 94 | state: Mutex::new(10), 95 | str: String::new(), 96 | })) 97 | .layer(from_fn(do_nothing)); 98 | 99 | let grpc_router2 = Router::new() 100 | .nest_tonic(Test2Server::new(Test2ServiceWithConnectInfo)) 101 | .layer(from_fn(cancel_request)); 102 | 103 | let grpc_router = grpc_router1.merge(grpc_router2); 104 | 105 | let rest_router = Router::new().merge(Router::new().route("/123", get(|| async move {}))); 106 | 107 | let service = RestGrpcService::new(rest_router, grpc_router).into_make_service_with_connect_info::(); 108 | 109 | axum::serve(listener, service).await.unwrap(); 110 | }); 111 | 112 | tokio::time::sleep(Duration::from_millis(100)).await; 113 | 114 | let channel = Channel::from_static(address).connect().await.unwrap(); 115 | 116 | let mut client1 = Test1Client::new(channel.clone()); 117 | client1.test1(Test1Request {}).await.unwrap(); 118 | client1.test1(Test1Request {}).await.unwrap(); 119 | client1.test1(Test1Request {}).await.unwrap(); 120 | client1.test1(Test1Request {}).await.unwrap(); 121 |
client1.test1(Test1Request {}).await.unwrap(); 122 | 123 | let channel = Channel::from_static(address).connect().await.unwrap(); // second connection; only client2 below uses it, client1 keeps the first channel 124 | 125 | client1.test1(Test1Request {}).await.unwrap(); 126 | client1.test1(Test1Request {}).await.unwrap(); 127 | client1.test1(Test1Request {}).await.unwrap(); 128 | client1.test1(Test1Request {}).await.unwrap(); 129 | client1.test1(Test1Request {}).await.unwrap(); 130 | 131 | let mut client2 = Test2Client::new(channel); 132 | assert_eq!( 133 | client2.test2(Test2Request {}).await.unwrap_err().code(), 134 | tonic::Code::Cancelled, 135 | ); 136 | } 137 | -------------------------------------------------------------------------------- /src/rest_grpc.rs: -------------------------------------------------------------------------------- 1 | use axum::{ 2 | extract::connect_info::{ConnectInfo, Connected}, 3 | http::header::CONTENT_TYPE, 4 | Router 5 | }; 6 | use futures::ready; 7 | use hyper::{Request, Response}; 8 | use std::{ 9 | convert::Infallible, 10 | task::{Context, Poll}, 11 | any::Any, 12 | net::SocketAddr, 13 | }; 14 | use tonic::transport::server::TcpConnectInfo; 15 | use tower::{make::Shared, Service}; 16 | 17 | /// This service splits all incoming requests either to the rest-service, or to 18 | /// the grpc-service based on the `content-type` header. 19 | /// 20 | /// Only requests whose `content-type` header starts with `application/grpc` (e.g. `application/grpc+proto`) are handled 21 | /// by the grpc-service. All other requests go to the rest-service. 22 | #[derive(Debug, Clone)] 23 | pub struct RestGrpcService { 24 | rest_router: Router, 25 | rest_ready: bool, 26 | grpc_router: Router, 27 | grpc_ready: bool, 28 | } 29 | 30 | impl RestGrpcService { 31 | /// Create a new service, which splits requests between the rest- and grpc-router; both start out not-ready, so `poll_ready` must succeed before `call`.
32 | pub fn new(rest_router: Router, grpc_router: Router) -> Self { 33 | Self { 34 | rest_router, 35 | rest_ready: false, 36 | grpc_router, 37 | grpc_ready: false, 38 | } 39 | } 40 | 41 | /// Create a make-service from this service. This make-service can be passed directly 42 | /// to `axum::serve`. 43 | /// 44 | /// If you would like to add shared middleware for both the rest-service and the grpc-service, 45 | /// the following approach is recommended: 46 | /// 47 | /// ```ignore 48 | /// use axum_tonic::RestGrpcService; 49 | /// use tokio::net::TcpListener; 50 | /// use tower::ServiceBuilder; 51 | /// 52 | /// let svc: RestGrpcService = my_service(); 53 | /// 54 | /// let svc_with_layers = ServiceBuilder::new() 55 | /// .buffer(5) 56 | /// .layer(my_layer1()) 57 | /// .layer(my_layer2()) 58 | /// .service(svc); 59 | /// 60 | /// axum::serve(TcpListener::bind("127.0.0.1:3000").await.unwrap(), svc_with_layers) 61 | /// .await 62 | /// .unwrap(); 63 | /// ``` 64 | pub fn into_make_service(self) -> Shared { 65 | Shared::new(self) 66 | } 67 | 68 | /// Create a make-service with connect info from this service. 69 | /// This allows you to extract connection information in your handlers using 70 | /// `extract::ConnectInfo`.
71 | /// 72 | /// Example: 73 | /// 74 | /// ```ignore 75 | /// use axum_tonic::RestGrpcService; 76 | /// use axum::extract::ConnectInfo; 77 | /// use tokio::net::TcpListener; 78 | /// use std::net::SocketAddr; 79 | /// 80 | /// // Create a router with a handler that uses connect info 81 | /// let rest_router = axum::Router::new() 82 | /// .route("/", axum::routing::get( 83 | /// |ConnectInfo(addr): ConnectInfo| async move { 84 | /// format!("Hello from IP: {}", addr) 85 | /// } 86 | /// )); 87 | /// 88 | /// let grpc_router = // Your gRPC router 89 | /// 90 | /// let svc = RestGrpcService::new(rest_router, grpc_router); 91 | /// 92 | /// // Use with connect info 93 | /// axum::serve( 94 | /// TcpListener::bind("127.0.0.1:3000").await.unwrap(), 95 | /// svc.into_make_service_with_connect_info::() 96 | /// ).await.unwrap(); 97 | /// ``` 98 | pub fn into_make_service_with_connect_info(self) -> RestGrpcMakeServiceWithConnectInfo 99 | where 100 | C: Send + Sync + Clone + 'static, 101 | { 102 | RestGrpcMakeServiceWithConnectInfo { 103 | inner: self, 104 | _connect_info: std::marker::PhantomData, 105 | } 106 | } 107 | } 108 | 109 | 110 | /// A make-service that, for every new connection, computes `C::connect_info(target)` and returns a `RestGrpcServiceWithConnectInfo` holding it alongside a clone of the inner service.
111 | #[derive(Clone)] 112 | pub struct RestGrpcMakeServiceWithConnectInfo { 113 | inner: RestGrpcService, 114 | _connect_info: std::marker::PhantomData C>, 115 | } 116 | 117 | impl Service for RestGrpcMakeServiceWithConnectInfo 118 | where 119 | C: Connected + Send + Sync + Clone + 'static, 120 | Target: Send, // Target is typically `axum::serve::IncomingStream` (a borrowed, per-connection value), which is Send but not 'static 121 | { 122 | type Response = RestGrpcServiceWithConnectInfo; 123 | type Error = Infallible; 124 | type Future = std::future::Ready>; 125 | 126 | fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { 127 | Poll::Ready(Ok(())) 128 | } 129 | 130 | fn call(&mut self, target: Target) -> Self::Future { 131 | let connect_info = C::connect_info(target); 132 | let inner = self.inner.clone(); 133 | 134 | std::future::ready(Ok(RestGrpcServiceWithConnectInfo { 135 | inner, 136 | connect_info, 137 | })) 138 | } 139 | } 140 | 141 | /// A service that holds both the RestGrpcService and the connection info, and inserts that connection info into the extensions of every request before routing it. 142 | #[derive(Clone)] 143 | pub struct RestGrpcServiceWithConnectInfo { 144 | inner: RestGrpcService, 145 | connect_info: C, 146 | } 147 | 148 | impl Service> for RestGrpcServiceWithConnectInfo 149 | where 150 | C: Send + Sync + Clone + 'static, 151 | ReqBody: http_body::Body + Send + 'static, 152 | ReqBody::Error: Into>, 153 | { 154 | type Response = Response; 155 | type Error = Infallible; 156 | type Future = >>::Future; 157 | 158 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { 159 | // drive readiness for each inner service and record which is ready 160 | loop { 161 | match (self.inner.rest_ready, self.inner.grpc_ready) { 162 | (true, true) => { 163 | return Ok(()).into(); 164 | } 165 | (false, _) => { 166 | ready!( 167 | >>::poll_ready( 168 | &mut self.inner.rest_router, 169 | cx 170 | ) 171 | )?; 172 | self.inner.rest_ready = true; 173 | } 174 | (_, false) => { // poll the grpc router itself; polling rest_router here would mark grpc ready without checking it 175 | ready!( 176 | >>::poll_ready( 177 | &mut self.inner.grpc_router, 178 | cx 179 | ) 180 | )?; 181 |
self.inner.grpc_ready = true; 182 | } 183 | } 184 | } 185 | } 186 | 187 | fn call(&mut self, mut req: Request) -> Self::Future { 188 | // require users to call `poll_ready` first, if they don't we're allowed to panic 189 | // as per the `tower::Service` contract 190 | assert!( 191 | self.inner.grpc_ready, 192 | "grpc service not ready. Did you forget to call `poll_ready`?" 193 | ); 194 | assert!( 195 | self.inner.rest_ready, 196 | "rest service not ready. Did you forget to call `poll_ready`?" 197 | ); 198 | 199 | // Store connect info in request extensions for all requests passing through 200 | // this service, so Axum handlers in both rest_router and grpc_router can access it. 201 | req.extensions_mut().insert(ConnectInfo(self.connect_info.clone())); 202 | 203 | // When C is a SocketAddr, tonic via grpc_router also gets a TcpConnectInfo; this populates 204 | // `request.remote_addr()` but not `.local_addr()`. 205 | if let Some(socket_addr) = (&self.connect_info as &dyn Any).downcast_ref::() { 206 | req.extensions_mut().insert(TcpConnectInfo { 207 | local_addr: None, 208 | remote_addr: Some(*socket_addr), 209 | }); 210 | } 211 | 212 | // if we get a grpc request call the grpc service, otherwise call the rest service 213 | // when calling a service it becomes not-ready so we have to drive readiness again 214 | if is_grpc_request(&req) { 215 | self.inner.grpc_ready = false; 216 | self.inner.grpc_router.call(req) 217 | } else { 218 | self.inner.rest_ready = false; 219 | self.inner.rest_router.call(req) 220 | } 221 | } 222 | } 223 | 224 | impl Service> for RestGrpcService 225 | where 226 | ReqBody: http_body::Body + Send + 'static, 227 | ReqBody::Error: Into>, 228 | { 229 | type Response = Response; 230 | type Error = Infallible; 231 | type Future = >>::Future; 232 | 233 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { 234 | // drive readiness for each inner service and record which is ready 235 | loop { 236 | match (self.rest_ready, self.grpc_ready) { 237 | (true, true)
=> { 238 | return Ok(()).into(); 239 | } 240 | (false, _) => { 241 | ready!( 242 | >>::poll_ready( 243 | &mut self.rest_router, 244 | cx 245 | ) 246 | )?; 247 | self.rest_ready = true; 248 | } 249 | (_, false) => { // poll the grpc router itself; polling rest_router here would mark grpc ready without checking it 250 | ready!( 251 | >>::poll_ready( 252 | &mut self.grpc_router, 253 | cx 254 | ) 255 | )?; 256 | self.grpc_ready = true; 257 | } 258 | } 259 | } 260 | } 261 | 262 | fn call(&mut self, req: Request) -> Self::Future { 263 | // require users to call `poll_ready` first, if they don't we're allowed to panic 264 | // as per the `tower::Service` contract 265 | assert!( 266 | self.grpc_ready, 267 | "grpc service not ready. Did you forget to call `poll_ready`?" 268 | ); 269 | assert!( 270 | self.rest_ready, 271 | "rest service not ready. Did you forget to call `poll_ready`?" 272 | ); 273 | 274 | // if we get a grpc request call the grpc service, otherwise call the rest service 275 | // when calling a service it becomes not-ready so we have to drive readiness again 276 | if is_grpc_request(&req) { 277 | self.grpc_ready = false; 278 | self.grpc_router.call(req) 279 | } else { 280 | self.rest_ready = false; 281 | self.rest_router.call(req) 282 | } 283 | } 284 | } 285 | 286 | fn is_grpc_request(req: &Request) -> bool { 287 | req.headers() 288 | .get(CONTENT_TYPE) 289 | .map(|content_type| content_type.as_bytes()) 290 | .filter(|content_type| content_type.starts_with(b"application/grpc")) 291 | .is_some() 292 | } 293 | -------------------------------------------------------------------------------- /tests/common/proto.rs: -------------------------------------------------------------------------------- 1 | #[derive(Clone, PartialEq, ::prost::Message)] 2 | pub struct Test1Request {} 3 | #[derive(Clone, PartialEq, ::prost::Message)] 4 | pub struct Test1Reply {} 5 | #[derive(Clone, PartialEq, ::prost::Message)] 6 | pub struct Test2Request {} 7 | #[derive(Clone, PartialEq, ::prost::Message)] 8 | pub struct Test2Reply {} 9 | /// Generated client implementations.
pub mod test1_client {
    #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
    // Client stubs for the `proto.Test1` service.
    //
    // NOTE(review): in this copy the angle-bracketed generic arguments appear to have been
    // stripped. For example, `Test1Client` is declared with no type parameter yet its methods
    // take `inner: T`, and `Result` is written without arguments. Restore them from the
    // tonic-build output before treating this text as compilable Rust.
    use tonic::codegen::http::Uri;
    use tonic::codegen::*;
    /// Client handle for `proto.Test1`. Every call goes through the
    /// `tonic::client::Grpc` wrapper stored in `inner`.
    #[derive(Debug, Clone)]
    pub struct Test1Client {
        inner: tonic::client::Grpc,
    }
    impl Test1Client {
        /// Attempt to create a new client by connecting to a given endpoint.
        ///
        /// Builds a `tonic::transport::Endpoint` from `dst`, awaits `connect()`, and wraps the
        /// resulting connection with `Self::new`. Errors from either step are propagated with `?`.
        pub async fn connect(dst: D) -> Result
        where
            D: std::convert::TryInto,
            D::Error: Into,
        {
            let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
            Ok(Self::new(conn))
        }
    }
    impl Test1Client
    where
        T: tonic::client::GrpcService,
        T::Error: Into,
        T::ResponseBody: Body + Send + 'static,
        ::Error: Into + Send,
    {
        /// Wraps an existing service `inner` in a `tonic::client::Grpc`.
        pub fn new(inner: T) -> Self {
            let inner = tonic::client::Grpc::new(inner);
            Self { inner }
        }
        /// Like `new`, but builds the `Grpc` wrapper with an explicit `origin` URI.
        pub fn with_origin(inner: T, origin: Uri) -> Self {
            let inner = tonic::client::Grpc::with_origin(inner, origin);
            Self { inner }
        }
        /// Pairs `inner` with `interceptor` in an `InterceptedService` and builds a client
        /// over that wrapped service.
        pub fn with_interceptor(
            inner: T,
            interceptor: F,
        ) -> Test1Client>
        where
            F: tonic::service::Interceptor,
            T::ResponseBody: Default,
            T: tonic::codegen::Service<
                http::Request,
                Response = http::Response<
                    >::ResponseBody,
                >,
            >,
            >>::Error:
                Into + Send + Sync,
        {
            Test1Client::new(InterceptedService::new(inner, interceptor))
        }
        /// Compress requests with the given encoding.
        ///
        /// This requires the server to support it otherwise it might respond with an
        /// error.
        #[must_use]
        pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
            self.inner = self.inner.send_compressed(encoding);
            self
        }
        /// Enable decompressing responses.
        #[must_use]
        pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
            self.inner = self.inner.accept_compressed(encoding);
            self
        }
        /// Unary call to `/proto.Test1/test1`.
        ///
        /// First waits for the inner service to become ready. A readiness failure is returned
        /// as a `tonic::Status` with `Code::Unknown` and a "Service was not ready: ..." message.
        /// The request is then sent with the prost codec.
        pub async fn test1(
            &mut self,
            request: impl tonic::IntoRequest,
        ) -> Result, tonic::Status> {
            self.inner.ready().await.map_err(|e| {
                tonic::Status::new(
                    tonic::Code::Unknown,
                    format!("Service was not ready: {}", e.into()),
                )
            })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static("/proto.Test1/test1");
            self.inner.unary(request.into_request(), path, codec).await
        }
    }
}
/// Generated client implementations.
pub mod test2_client {
    #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
    // Client stubs for the `proto.Test2` service. Same shape as `test1_client`.
    //
    // NOTE(review): generic arguments appear stripped in this copy. `Test2Client` has no type
    // parameter yet its methods take `inner: T`, and `Result` has no arguments. Restore them
    // from the tonic-build output before compiling.
    use tonic::codegen::http::Uri;
    use tonic::codegen::*;
    /// Client handle for `proto.Test2`. Every call goes through the
    /// `tonic::client::Grpc` wrapper stored in `inner`.
    #[derive(Debug, Clone)]
    pub struct Test2Client {
        inner: tonic::client::Grpc,
    }
    impl Test2Client {
        /// Attempt to create a new client by connecting to a given endpoint.
        ///
        /// Builds a `tonic::transport::Endpoint` from `dst`, awaits `connect()`, and wraps the
        /// resulting connection with `Self::new`. Errors from either step are propagated with `?`.
        pub async fn connect(dst: D) -> Result
        where
            D: std::convert::TryInto,
            D::Error: Into,
        {
            let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
            Ok(Self::new(conn))
        }
    }
    impl Test2Client
    where
        T: tonic::client::GrpcService,
        T::Error: Into,
        T::ResponseBody: Body + Send + 'static,
        ::Error: Into + Send,
    {
        /// Wraps an existing service `inner` in a `tonic::client::Grpc`.
        pub fn new(inner: T) -> Self {
            let inner = tonic::client::Grpc::new(inner);
            Self { inner }
        }
        /// Like `new`, but builds the `Grpc` wrapper with an explicit `origin` URI.
        pub fn with_origin(inner: T, origin: Uri) -> Self {
            let inner = tonic::client::Grpc::with_origin(inner, origin);
            Self { inner }
        }
        /// Pairs `inner` with `interceptor` in an `InterceptedService` and builds a client
        /// over that wrapped service.
        pub fn with_interceptor(
            inner: T,
            interceptor: F,
        ) -> Test2Client>
        where
            F: tonic::service::Interceptor,
            T::ResponseBody: Default,
            T: tonic::codegen::Service<
                http::Request,
                Response = http::Response<
                    >::ResponseBody,
                >,
            >,
            >>::Error:
                Into + Send + Sync,
        {
            Test2Client::new(InterceptedService::new(inner, interceptor))
        }
        /// Compress requests with the given encoding.
        ///
        /// This requires the server to support it otherwise it might respond with an
        /// error.
        #[must_use]
        pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
            self.inner = self.inner.send_compressed(encoding);
            self
        }
        /// Enable decompressing responses.
        #[must_use]
        pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
            self.inner = self.inner.accept_compressed(encoding);
            self
        }
        /// Unary call to `/proto.Test2/test2`.
        ///
        /// First waits for the inner service to become ready. A readiness failure is returned
        /// as a `tonic::Status` with `Code::Unknown` and a "Service was not ready: ..." message.
        /// The request is then sent with the prost codec.
        pub async fn test2(
            &mut self,
            request: impl tonic::IntoRequest,
        ) -> Result, tonic::Status> {
            self.inner.ready().await.map_err(|e| {
                tonic::Status::new(
                    tonic::Code::Unknown,
                    format!("Service was not ready: {}", e.into()),
                )
            })?;
            let codec = tonic::codec::ProstCodec::default();
            let path = http::uri::PathAndQuery::from_static("/proto.Test2/test2");
            self.inner.unary(request.into_request(), path, codec).await
        }
    }
}
/// Generated server implementations.
pub mod test1_server {
    #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
    // Server-side glue for `proto.Test1`. It has two parts:
    // - the `Test1` trait, which the application implements;
    // - `Test1Server`, which routes incoming HTTP requests to that implementation.
    //
    // NOTE(review): generic arguments appear stripped in this copy. For example, `Test1Server`
    // and `_Inner` are declared without a type parameter while their bodies use `T`. Restore
    // them from the tonic-build output before compiling.
    use tonic::codegen::*;
    ///Generated trait containing gRPC methods that should be implemented for use with Test1Server.
    #[async_trait]
    pub trait Test1: Send + Sync + 'static {
        /// Handles one `/proto.Test1/test1` request.
        async fn test1(
            &self,
            request: tonic::Request,
        ) -> Result, tonic::Status>;
    }
    /// Routes requests to a shared `Test1` implementation. It also carries the compression
    /// settings that are applied to each call.
    #[derive(Debug)]
    pub struct Test1Server {
        inner: _Inner,
        // Both encoding sets start at `Default::default()` (see `from_arc`). They are changed
        // only through the `accept_compressed` / `send_compressed` builders.
        accept_compression_encodings: EnabledCompressionEncodings,
        send_compression_encodings: EnabledCompressionEncodings,
    }
    // Newtype around the `Arc`-shared service implementation. It is cloned per request in `call`.
    struct _Inner(Arc);
    impl Test1Server {
        /// Wraps `inner` in a fresh `Arc`; otherwise identical to `from_arc`.
        pub fn new(inner: T) -> Self {
            Self::from_arc(Arc::new(inner))
        }
        /// Builds a server around an already-shared implementation.
        /// Compression settings start at their defaults.
        pub fn from_arc(inner: Arc) -> Self {
            let inner = _Inner(inner);
            Self {
                inner,
                accept_compression_encodings: Default::default(),
                send_compression_encodings: Default::default(),
            }
        }
        /// Builds a server with `Self::new(inner)` and pairs it with `interceptor` in an
        /// `InterceptedService`.
        pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService
        where
            F: tonic::service::Interceptor,
        {
            InterceptedService::new(Self::new(inner), interceptor)
        }
        /// Enable decompressing requests with the given encoding.
        #[must_use]
        pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
            self.accept_compression_encodings.enable(encoding);
            self
        }
        /// Compress responses with the given encoding, if the client supports it.
        #[must_use]
        pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
            self.send_compression_encodings.enable(encoding);
            self
        }
    }
    impl tonic::codegen::Service> for Test1Server
    where
        T: Test1,
        B: Body + Send + 'static,
        B::Error: Into + Send + 'static,
    {
        type Response = http::Response;
        // `Infallible`: every outcome, including unknown paths, is encoded in the HTTP
        // response rather than returned as a service error.
        type Error = std::convert::Infallible;
        type Future = BoxFuture;
        // Unconditionally ready, so this layer never signals back-pressure.
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
            Poll::Ready(Ok(()))
        }
        // Dispatches on the request path. Only `/proto.Test1/test1` is served.
        fn call(&mut self, req: http::Request) -> Self::Future {
            // NOTE(review): this clone is never read. The matched arm clones `self.inner` again
            // (shadowing it), and the fallback arm does not use it. Left as generated; the module
            // allows `unused_variables`.
            let inner = self.inner.clone();
            match req.uri().path() {
                "/proto.Test1/test1" => {
                    // Adapts the shared `Test1` implementation to tonic's `UnaryService`.
                    #[allow(non_camel_case_types)]
                    struct test1Svc(pub Arc);
                    impl tonic::server::UnaryService for test1Svc {
                        type Response = super::Test1Reply;
                        type Future = BoxFuture, tonic::Status>;
                        fn call(
                            &mut self,
                            request: tonic::Request,
                        ) -> Self::Future {
                            let inner = self.0.clone();
                            let fut = async move { (*inner).test1(request).await };
                            Box::pin(fut)
                        }
                    }
                    // Copy the settings and the handle out of `self` so the boxed `async move`
                    // future owns everything it uses.
                    let accept_compression_encodings = self.accept_compression_encodings;
                    let send_compression_encodings = self.send_compression_encodings;
                    let inner = self.inner.clone();
                    let fut = async move {
                        let inner = inner.0;
                        let method = test1Svc(inner);
                        let codec = tonic::codec::ProstCodec::default();
                        let mut grpc = tonic::server::Grpc::new(codec).apply_compression_config(
                            accept_compression_encodings,
                            send_compression_encodings,
                        );
                        let res = grpc.unary(method, req).await;
                        Ok(res)
                    };
                    Box::pin(fut)
                }
                // Any other path gets HTTP 200 with an empty body and `grpc-status: 12`, which is
                // the gRPC UNIMPLEMENTED code. The status and headers are fixed literals, so the
                // builder's `unwrap` is not expected to fire.
                _ => Box::pin(async move {
                    Ok(http::Response::builder()
                        .status(200)
                        .header("grpc-status", "12")
                        .header("content-type", "application/grpc")
                        .body(tonic::body::Body::empty())
                        .unwrap())
                }),
            }
        }
    }
    // Hand-written Clone: clones the shared `_Inner` handle and copies both compression settings.
    impl Clone for Test1Server {
        fn clone(&self) -> Self {
            let inner = self.inner.clone();
            Self {
                inner,
                accept_compression_encodings: self.accept_compression_encodings,
                send_compression_encodings: self.send_compression_encodings,
            }
        }
    }
    // Clones the wrapped `Arc`, which shares the same implementation rather than copying it.
    impl Clone for _Inner {
        fn clone(&self) -> Self {
            Self(self.0.clone())
        }
    }
    // Delegates to the wrapped value's `Debug` output.
    impl std::fmt::Debug for _Inner {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{:?}", self.0)
        }
    }
    // Fully-qualified service name. It matches the `/proto.Test1/...` route handled in `call`.
    impl tonic::server::NamedService for Test1Server {
        const NAME: &'static str = "proto.Test1";
    }
}
/// Generated server implementations.
pub mod test2_server {
    #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
    // Server-side glue for `proto.Test2`. Same shape as `test1_server`.
    //
    // NOTE(review): generic arguments appear stripped in this copy. For example, `Test2Server`
    // and `_Inner` are declared without a type parameter while their bodies use `T`. Restore
    // them from the tonic-build output before compiling.
    use tonic::codegen::*;
    ///Generated trait containing gRPC methods that should be implemented for use with Test2Server.
    #[async_trait]
    pub trait Test2: Send + Sync + 'static {
        /// Handles one `/proto.Test2/test2` request.
        async fn test2(
            &self,
            request: tonic::Request,
        ) -> Result, tonic::Status>;
    }
    /// Routes requests to a shared `Test2` implementation. It also carries the compression
    /// settings that are applied to each call.
    #[derive(Debug)]
    pub struct Test2Server {
        inner: _Inner,
        // Both encoding sets start at `Default::default()` (see `from_arc`). They are changed
        // only through the `accept_compressed` / `send_compressed` builders.
        accept_compression_encodings: EnabledCompressionEncodings,
        send_compression_encodings: EnabledCompressionEncodings,
    }
    // Newtype around the `Arc`-shared service implementation. It is cloned per request in `call`.
    struct _Inner(Arc);
    impl Test2Server {
        /// Wraps `inner` in a fresh `Arc`; otherwise identical to `from_arc`.
        pub fn new(inner: T) -> Self {
            Self::from_arc(Arc::new(inner))
        }
        /// Builds a server around an already-shared implementation.
        /// Compression settings start at their defaults.
        pub fn from_arc(inner: Arc) -> Self {
            let inner = _Inner(inner);
            Self {
                inner,
                accept_compression_encodings: Default::default(),
                send_compression_encodings: Default::default(),
            }
        }
        /// Builds a server with `Self::new(inner)` and pairs it with `interceptor` in an
        /// `InterceptedService`.
        pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService
        where
            F: tonic::service::Interceptor,
        {
            InterceptedService::new(Self::new(inner), interceptor)
        }
        /// Enable decompressing requests with the given encoding.
        #[must_use]
        pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
            self.accept_compression_encodings.enable(encoding);
            self
        }
        /// Compress responses with the given encoding, if the client supports it.
        #[must_use]
        pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
            self.send_compression_encodings.enable(encoding);
            self
        }
    }
    impl tonic::codegen::Service> for Test2Server
    where
        T: Test2,
        B: Body + Send + 'static,
        B::Error: Into + Send + 'static,
    {
        type Response = http::Response;
        // `Infallible`: every outcome, including unknown paths, is encoded in the HTTP
        // response rather than returned as a service error.
        type Error = std::convert::Infallible;
        type Future = BoxFuture;
        // Unconditionally ready, so this layer never signals back-pressure.
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
            Poll::Ready(Ok(()))
        }
        // Dispatches on the request path. Only `/proto.Test2/test2` is served.
        fn call(&mut self, req: http::Request) -> Self::Future {
            // NOTE(review): this clone is never read. The matched arm clones `self.inner` again
            // (shadowing it), and the fallback arm does not use it. Left as generated; the module
            // allows `unused_variables`.
            let inner = self.inner.clone();
            match req.uri().path() {
                "/proto.Test2/test2" => {
                    // Adapts the shared `Test2` implementation to tonic's `UnaryService`.
                    #[allow(non_camel_case_types)]
                    struct test2Svc(pub Arc);
                    impl tonic::server::UnaryService for test2Svc {
                        type Response = super::Test2Reply;
                        type Future = BoxFuture, tonic::Status>;
                        fn call(
                            &mut self,
                            request: tonic::Request,
                        ) -> Self::Future {
                            let inner = self.0.clone();
                            let fut = async move { (*inner).test2(request).await };
                            Box::pin(fut)
                        }
                    }
                    // Copy the settings and the handle out of `self` so the boxed `async move`
                    // future owns everything it uses.
                    let accept_compression_encodings = self.accept_compression_encodings;
                    let send_compression_encodings = self.send_compression_encodings;
                    let inner = self.inner.clone();
                    let fut = async move {
                        let inner = inner.0;
                        let method = test2Svc(inner);
                        let codec = tonic::codec::ProstCodec::default();
                        let mut grpc = tonic::server::Grpc::new(codec).apply_compression_config(
                            accept_compression_encodings,
                            send_compression_encodings,
                        );
                        let res = grpc.unary(method, req).await;
                        Ok(res)
                    };
                    Box::pin(fut)
                }
                // Any other path gets HTTP 200 with an empty body and `grpc-status: 12`, which is
                // the gRPC UNIMPLEMENTED code. The status and headers are fixed literals, so the
                // builder's `unwrap` is not expected to fire.
                _ => Box::pin(async move {
                    Ok(http::Response::builder()
                        .status(200)
                        .header("grpc-status", "12")
                        .header("content-type", "application/grpc")
                        .body(tonic::body::Body::empty())
                        .unwrap())
                }),
            }
        }
    }
    // Hand-written Clone: clones the shared `_Inner` handle and copies both compression settings.
    impl Clone for Test2Server {
        fn clone(&self) -> Self {
            let inner = self.inner.clone();
            Self {
                inner,
                accept_compression_encodings: self.accept_compression_encodings,
                send_compression_encodings: self.send_compression_encodings,
            }
        }
    }
    // Clones the wrapped `Arc`, which shares the same implementation rather than copying it.
    impl Clone for _Inner {
        fn clone(&self) -> Self {
            Self(self.0.clone())
        }
    }
    // Delegates to the wrapped value's `Debug` output.
    impl std::fmt::Debug for _Inner {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{:?}", self.0)
        }
    }
    // Fully-qualified service name. It matches the `/proto.Test2/...` route handled in `call`.
    impl tonic::server::NamedService for Test2Server {
        const NAME: &'static str = "proto.Test2";
    }
}
--------------------------------------------------------------------------------