├── .github └── workflows │ └── CI.yml ├── .gitignore ├── CHANGELOG.md ├── Cargo.toml ├── LICENSE ├── README.md ├── examples ├── client.rs ├── server.rs └── server_graceful.rs ├── src ├── client │ ├── client.rs │ ├── legacy │ │ ├── client.rs │ │ ├── connect │ │ │ ├── capture.rs │ │ │ ├── dns.rs │ │ │ ├── http.rs │ │ │ ├── mod.rs │ │ │ └── proxy │ │ │ │ ├── mod.rs │ │ │ │ ├── socks │ │ │ │ ├── mod.rs │ │ │ │ ├── v4 │ │ │ │ │ ├── errors.rs │ │ │ │ │ ├── messages.rs │ │ │ │ │ └── mod.rs │ │ │ │ └── v5 │ │ │ │ │ ├── errors.rs │ │ │ │ │ ├── messages.rs │ │ │ │ │ └── mod.rs │ │ │ │ └── tunnel.rs │ │ ├── mod.rs │ │ └── pool.rs │ ├── mod.rs │ ├── proxy │ │ ├── matcher.rs │ │ └── mod.rs │ └── service.rs ├── common │ ├── exec.rs │ ├── future.rs │ ├── lazy.rs │ ├── mod.rs │ ├── rewind.rs │ ├── sync.rs │ └── timer.rs ├── error.rs ├── lib.rs ├── rt │ ├── io.rs │ ├── mod.rs │ ├── tokio.rs │ └── tokio │ │ ├── with_hyper_io.rs │ │ └── with_tokio_io.rs ├── server │ ├── conn │ │ ├── auto │ │ │ ├── mod.rs │ │ │ └── upgrade.rs │ │ └── mod.rs │ ├── graceful.rs │ └── mod.rs └── service │ ├── glue.rs │ ├── mod.rs │ └── oneshot.rs └── tests ├── legacy_client.rs ├── proxy.rs └── test_utils └── mod.rs /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | pull_request: 4 | push: 5 | branches: 6 | - master 7 | 8 | env: 9 | RUST_BACKTRACE: 1 10 | 11 | jobs: 12 | ci-pass: 13 | name: CI is green 14 | runs-on: ubuntu-latest 15 | needs: 16 | - style 17 | - test 18 | - msrv 19 | - miri 20 | - features 21 | - semver 22 | - doc 23 | steps: 24 | - run: exit 0 25 | 26 | style: 27 | name: Check Style 28 | runs-on: ubuntu-latest 29 | steps: 30 | - uses: actions/checkout@v4 31 | - uses: dtolnay/rust-toolchain@stable 32 | with: 33 | components: rustfmt 34 | - run: cargo fmt --all --check 35 | 36 | test: 37 | name: Test ${{ matrix.rust }} on ${{ matrix.os }} 38 | needs: [style] 39 | strategy: 40 | matrix: 41 | rust: 42 | - stable 
43 | - beta 44 | - nightly 45 | os: 46 | - ubuntu-latest 47 | - windows-latest 48 | - macos-latest 49 | runs-on: ${{ matrix.os }} 50 | steps: 51 | - uses: actions/checkout@v4 52 | - name: Install Rust (${{ matrix.rust }}) 53 | uses: dtolnay/rust-toolchain@master 54 | with: 55 | toolchain: ${{ matrix.rust }} 56 | - run: cargo test --all-features 57 | 58 | msrv: 59 | name: Check MSRV (${{ matrix.rust }}) 60 | needs: [style] 61 | strategy: 62 | matrix: 63 | rust: [ 1.63 ] # keep in sync with 'rust-version' in Cargo.toml 64 | runs-on: ubuntu-latest 65 | steps: 66 | - uses: actions/checkout@v4 67 | - uses: dtolnay/rust-toolchain@stable 68 | - name: Pin some dependencies for MSRV 69 | run: | 70 | cargo update 71 | cargo update --package tokio --precise 1.38.1 72 | cargo update --package tokio-util --precise 0.7.11 73 | cargo update --package hashbrown --precise 0.15.0 74 | cargo update --package once_cell --precise 1.20.3 75 | - name: Install Rust (${{ matrix.rust }}) 76 | uses: dtolnay/rust-toolchain@master 77 | with: 78 | toolchain: ${{ matrix.rust }} 79 | - run: cargo check --features full 80 | 81 | miri: 82 | name: Test with Miri 83 | needs: [style] 84 | runs-on: ubuntu-latest 85 | steps: 86 | - uses: actions/checkout@v4 87 | - uses: dtolnay/rust-toolchain@nightly 88 | with: 89 | components: miri 90 | - name: Test 91 | env: 92 | # Can't enable tcp feature since Miri does not support the tokio runtime 93 | MIRIFLAGS: "-Zmiri-disable-isolation" 94 | run: cargo miri test --all-features 95 | 96 | features: 97 | name: features 98 | needs: [style] 99 | runs-on: ubuntu-latest 100 | steps: 101 | - uses: actions/checkout@v4 102 | - uses: dtolnay/rust-toolchain@stable 103 | - uses: taiki-e/install-action@cargo-hack 104 | - run: cargo hack --no-dev-deps check --feature-powerset --depth 2 105 | 106 | semver: 107 | name: semver 108 | runs-on: ubuntu-latest 109 | steps: 110 | - uses: actions/checkout@v4 111 | - name: Check semver 112 | uses: 
obi1kenobi/cargo-semver-checks-action@v2 113 | with: 114 | feature-group: only-explicit-features 115 | features: full 116 | release-type: minor 117 | 118 | doc: 119 | name: Build docs 120 | needs: [style, test] 121 | runs-on: ubuntu-latest 122 | steps: 123 | - uses: actions/checkout@v4 124 | - uses: dtolnay/rust-toolchain@nightly 125 | - run: cargo rustdoc -- --cfg docsrs -D rustdoc::broken-intra-doc-links 126 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 7 | Cargo.lock 8 | 9 | # These are backup files generated by rustfmt 10 | **/*.rs.bk 11 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # 0.1.13 (2025-05-27) 2 | 3 | - Fix `HttpConnector` to always prefer IPv6 addresses first, if happy eyeballs is enabled. 4 | - Fix `legacy::Client` to return better errors if available on the connection. 5 | 6 | # 0.1.12 (2025-05-19) 7 | 8 | - Add `client::legacy::proxy::Tunnel` connector that wraps another connector with HTTP tunneling. 9 | - Add `client::legacy::proxy::{SocksV4, SocksV5}` connectors that wraps another connector with SOCKS. 10 | - Add `client::proxy::matcher::Matcher` type that can use environment variables to match proxy rules. 11 | - Add `server::graceful::Watcher` type that can be sent to watch a connection in another task. 12 | - Add `GracefulShutdown::count()` method to get number of currently watched connections. 13 | - Fix missing `must_use` attributes on `Connection` futures. 
14 | - Fix tracing span in GAI resolver that can cause panics. 15 | 16 | 17 | # 0.1.11 (2025-03-31) 18 | 19 | - Add `tracing` crate feature with support in `TokioExecutor`. 20 | - Add `HttpConnector::interface()` support for macOS and Solarish systems. 21 | - Add `rt::WithHyperIo` and `rt::WithTokioIo` combinators. 22 | - Add `auto_date_header()` for auto server builder. 23 | - Add `max_local_error_reset_streams()` for auto server builder. 24 | - Add `ignore_invalid_headers()` for auto server builder. 25 | - Add methods to determine if auto server is configured for HTTP/1 or HTTP/2. 26 | - Implement `Connection` for `UnixStream` and `NamedPipeClient`. 27 | - Fix HTTP/2 websocket requests sent through `legacy::Client`. 28 | 29 | # 0.1.10 (2024-10-28) 30 | 31 | - Add `http2_max_header_list_size(num)` option to legacy client builder. 32 | - Add `set_tcp_user_timeout(dur)` option to legacy `HttpConnector`. 33 | 34 | # 0.1.9 (2024-09-24) 35 | 36 | - Add support for `client::legacy` DNS resolvers to set non-zero ports on returned addresses. 37 | - Fix `client::legacy` wrongly retrying pooled connections that were created successfully but failed immediately after, resulting in a retry loop. 38 | 39 | 40 | # 0.1.8 (2024-09-09) 41 | 42 | - Add `server::conn::auto::upgrade::downcast()` for use with auto connection upgrades. 43 | 44 | # 0.1.7 (2024-08-06) 45 | 46 | - Add `Connected::poison()` to `legacy` client, a port from hyper v0.14.x. 47 | - Add `Error::connect_info()` to `legacy` client, a port from hyper v0.14.x. 48 | 49 | # 0.1.6 (2024-07-01) 50 | 51 | - Add support for AIX operating system to `legacy` client. 52 | - Fix `legacy` client to better use dying pooled connections. 53 | 54 | # 0.1.5 (2024-05-28) 55 | 56 | - Add `server::graceful::GracefulShutdown` helper to coordinate over many connections. 57 | - Add `server::conn::auto::Connection::into_owned()` to unlink lifetime from `Builder`. 
58 | - Allow `service` module to be available with only `service` feature enabled. 59 | 60 | # 0.1.4 (2024-05-24) 61 | 62 | - Add `initial_max_send_streams()` to `legacy` client builder 63 | - Add `max_pending_accept_reset_streams()` to `legacy` client builder 64 | - Add `max_headers(usize)` to `auto` server builder 65 | - Add `http1_only()` and `http2_only()` to `auto` server builder 66 | - Add connection capturing API to `legacy` client 67 | - Add `impl Connection for TokioIo` 68 | - Fix graceful shutdown hanging on reading the HTTP version 69 | 70 | # 0.1.3 (2024-01-31) 71 | 72 | ### Added 73 | 74 | - Add `Error::is_connect()` which returns true if the error came from the client `Connect`. 75 | - Add timer support to `legacy` pool. 76 | - Add support to enable http1/http2 parts of `auto::Builder` individually. 77 | 78 | ### Fixed 79 | 80 | - Fix `auto` connection so it can handle requests shorter than the h2 preface. 81 | - Fix `legacy::Client` to no longer error when keep-alive is disabled. 82 | 83 | # 0.1.2 (2023-12-20) 84 | 85 | ### Added 86 | 87 | - Add `graceful_shutdown()` method to `auto` connections. 88 | - Add `rt::TokioTimer` type that implements `hyper::rt::Timer`. 89 | - Add `service::TowerToHyperService` adapter, allowing `tower::Service`s to be used as a `hyper::service::Service`. 90 | - Implement `Clone` for `auto::Builder`. 91 | - Export `legacy::{Builder, ResponseFuture}`. 92 | 93 | ### Fixed 94 | 95 | - Enable HTTP/1 upgrades on the `legacy::Client`. 96 | - Prevent divide by zero if DNS returns 0 addresses. 97 | 98 | # 0.1.1 (2023-11-17) 99 | 100 | ### Added 101 | 102 | - Make `server-auto` enable the `server` feature. 103 | 104 | ### Fixed 105 | 106 | - Reduce `Send` bounds requirements for `auto` connections. 107 | - Docs: enable all features when generating. 108 | 109 | # 0.1.0 (2023-11-16) 110 | 111 | Initial release.
112 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "hyper-util" 3 | version = "0.1.13" 4 | description = "hyper utilities" 5 | readme = "README.md" 6 | homepage = "https://hyper.rs" 7 | documentation = "https://docs.rs/hyper-util" 8 | repository = "https://github.com/hyperium/hyper-util" 9 | license = "MIT" 10 | authors = ["Sean McArthur "] 11 | keywords = ["http", "hyper", "hyperium"] 12 | categories = ["network-programming", "web-programming::http-client", "web-programming::http-server"] 13 | edition = "2021" 14 | rust-version = "1.63" 15 | 16 | [package.metadata.docs.rs] 17 | features = ["full"] 18 | rustdoc-args = ["--cfg", "docsrs"] 19 | 20 | [dependencies] 21 | base64 = { version = "0.22", optional = true } 22 | bytes = "1.7.1" 23 | futures-channel = { version = "0.3", optional = true } 24 | futures-core = { version = "0.3" } 25 | futures-util = { version = "0.3.16", default-features = false, optional = true } 26 | http = "1.0" 27 | http-body = "1.0.0" 28 | hyper = "1.6.0" 29 | ipnet = { version = "2.9", optional = true } 30 | libc = { version = "0.2", optional = true } 31 | percent-encoding = { version = "2.3", optional = true } 32 | pin-project-lite = "0.2.4" 33 | socket2 = { version = "0.5.9", optional = true, features = ["all"] } 34 | tracing = { version = "0.1", default-features = false, features = ["std"], optional = true } 35 | tokio = { version = "1", optional = true, default-features = false } 36 | tower-service = { version = "0.3", optional = true } 37 | 38 | [dev-dependencies] 39 | hyper = { version = "1.4.0", features = ["full"] } 40 | bytes = "1" 41 | futures-util = { version = "0.3.16", default-features = false, features = ["alloc"] } 42 | http-body-util = "0.1.0" 43 | tokio = { version = "1", features = ["macros", "test-util", "signal"] } 44 | tokio-test = "0.4" 45 | pretty_env_logger = "0.5" 
46 | 47 | [target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies] 48 | pnet_datalink = "0.35.0" 49 | 50 | [target.'cfg(target_os = "macos")'.dependencies] 51 | system-configuration = { version = "0.6.1", optional = true } 52 | 53 | [target.'cfg(windows)'.dependencies] 54 | windows-registry = { version = "0.5", optional = true } 55 | 56 | [features] 57 | default = [] 58 | 59 | # Shorthand to enable everything 60 | full = [ 61 | "client", 62 | "client-legacy", 63 | "server", 64 | "server-auto", 65 | "server-graceful", 66 | "service", 67 | "http1", 68 | "http2", 69 | "tokio", 70 | "tracing", 71 | ] 72 | 73 | client = ["hyper/client", "dep:tracing", "dep:futures-channel", "dep:tower-service"] 74 | client-legacy = ["client", "dep:socket2", "tokio/sync", "dep:libc", "dep:futures-util"] 75 | client-proxy = ["client", "dep:base64", "dep:ipnet", "dep:percent-encoding"] 76 | client-proxy-system = ["dep:system-configuration", "dep:windows-registry"] 77 | 78 | server = ["hyper/server"] 79 | server-auto = ["server", "http1", "http2"] 80 | server-graceful = ["server", "tokio/sync"] 81 | 82 | service = ["dep:tower-service"] 83 | 84 | http1 = ["hyper/http1"] 85 | http2 = ["hyper/http2"] 86 | 87 | tokio = ["dep:tokio", "tokio/net", "tokio/rt", "tokio/time"] 88 | 89 | tracing = ["dep:tracing"] 90 | 91 | # internal features used in CI 92 | __internal_happy_eyeballs_tests = [] 93 | 94 | [[example]] 95 | name = "client" 96 | required-features = ["client-legacy", "http1", "tokio"] 97 | 98 | [[example]] 99 | name = "server" 100 | required-features = ["server", "http1", "tokio"] 101 | 102 | [[example]] 103 | name = "server_graceful" 104 | required-features = ["tokio", "server-graceful", "server-auto"] 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023-2025 Sean McArthur 2 | 3 | Permission is hereby granted, free of 
charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hyper-util 2 | 3 | [![crates.io](https://img.shields.io/crates/v/hyper-util.svg)](https://crates.io/crates/hyper-util) 4 | [![Released API docs](https://docs.rs/hyper-util/badge.svg)](https://docs.rs/hyper-util) 5 | [![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE) 6 | 7 | A collection of utilities to do common things with [hyper](https://hyper.rs). 8 | 9 | ## License 10 | 11 | This project is licensed under the [MIT license](./LICENSE). 
12 | -------------------------------------------------------------------------------- /examples/client.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use http_body_util::Empty; 4 | use hyper::Request; 5 | use hyper_util::client::legacy::{connect::HttpConnector, Client}; 6 | 7 | #[tokio::main(flavor = "current_thread")] 8 | async fn main() -> Result<(), Box> { 9 | let url = match env::args().nth(1) { 10 | Some(url) => url, 11 | None => { 12 | eprintln!("Usage: client "); 13 | return Ok(()); 14 | } 15 | }; 16 | 17 | // HTTPS requires picking a TLS implementation, so give a better 18 | // warning if the user tries to request an 'https' URL. 19 | let url = url.parse::()?; 20 | if url.scheme_str() != Some("http") { 21 | eprintln!("This example only works with 'http' URLs."); 22 | return Ok(()); 23 | } 24 | 25 | let client = Client::builder(hyper_util::rt::TokioExecutor::new()).build(HttpConnector::new()); 26 | 27 | let req = Request::builder() 28 | .uri(url) 29 | .body(Empty::::new())?; 30 | 31 | let resp = client.request(req).await?; 32 | 33 | eprintln!("{:?} {:?}", resp.version(), resp.status()); 34 | eprintln!("{:#?}", resp.headers()); 35 | 36 | Ok(()) 37 | } 38 | -------------------------------------------------------------------------------- /examples/server.rs: -------------------------------------------------------------------------------- 1 | //! This example runs a server that responds to any request with "Hello, world!" 
2 | 3 | use std::{convert::Infallible, error::Error}; 4 | 5 | use bytes::Bytes; 6 | use http::{header::CONTENT_TYPE, Request, Response}; 7 | use http_body_util::{combinators::BoxBody, BodyExt, Full}; 8 | use hyper::{body::Incoming, service::service_fn}; 9 | use hyper_util::{ 10 | rt::{TokioExecutor, TokioIo}, 11 | server::conn::auto::Builder, 12 | }; 13 | use tokio::{net::TcpListener, task::JoinSet}; 14 | 15 | /// Function from an incoming request to an outgoing response 16 | /// 17 | /// This function gets turned into a [`hyper::service::Service`] later via 18 | /// [`service_fn`]. Instead of doing this, you could also write a type that 19 | /// implements [`hyper::service::Service`] directly and pass that in place of 20 | /// writing a function like this and calling [`service_fn`]. 21 | /// 22 | /// This function could use [`Full`] as the body type directly since that's 23 | /// the only type that can be returned in this case, but this uses [`BoxBody`] 24 | /// anyway for demonstration purposes, since this is what's usually used when 25 | /// writing a more complex webserver library. 
26 | async fn handle_request( 27 | _request: Request, 28 | ) -> Result>, Infallible> { 29 | let response = Response::builder() 30 | .header(CONTENT_TYPE, "text/plain") 31 | .body(Full::new(Bytes::from("Hello, world!\n")).boxed()) 32 | .expect("values provided to the builder should be valid"); 33 | 34 | Ok(response) 35 | } 36 | 37 | #[tokio::main(flavor = "current_thread")] 38 | async fn main() -> Result<(), Box> { 39 | let listen_addr = "127.0.0.1:8000"; 40 | let tcp_listener = TcpListener::bind(listen_addr).await?; 41 | println!("listening on http://{listen_addr}"); 42 | 43 | let mut join_set = JoinSet::new(); 44 | loop { 45 | let (stream, addr) = match tcp_listener.accept().await { 46 | Ok(x) => x, 47 | Err(e) => { 48 | eprintln!("failed to accept connection: {e}"); 49 | continue; 50 | } 51 | }; 52 | 53 | let serve_connection = async move { 54 | println!("handling a request from {addr}"); 55 | 56 | let result = Builder::new(TokioExecutor::new()) 57 | .serve_connection(TokioIo::new(stream), service_fn(handle_request)) 58 | .await; 59 | 60 | if let Err(e) = result { 61 | eprintln!("error serving {addr}: {e}"); 62 | } 63 | 64 | println!("handled a request from {addr}"); 65 | }; 66 | 67 | join_set.spawn(serve_connection); 68 | } 69 | 70 | // If you add a method for breaking the above loop (i.e. 
graceful shutdown), 71 | // then you may also want to wait for all existing connections to finish 72 | // being served before terminating the program, which can be done like this: 73 | // 74 | // while let Some(_) = join_set.join_next().await {} 75 | } 76 | -------------------------------------------------------------------------------- /examples/server_graceful.rs: -------------------------------------------------------------------------------- 1 | use bytes::Bytes; 2 | use std::convert::Infallible; 3 | use std::pin::pin; 4 | use std::time::Duration; 5 | use tokio::net::TcpListener; 6 | 7 | #[tokio::main(flavor = "current_thread")] 8 | async fn main() -> Result<(), Box> { 9 | let listener = TcpListener::bind("127.0.0.1:8080").await?; 10 | 11 | let server = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); 12 | let graceful = hyper_util::server::graceful::GracefulShutdown::new(); 13 | let mut ctrl_c = pin!(tokio::signal::ctrl_c()); 14 | 15 | loop { 16 | tokio::select! 
{ 17 | conn = listener.accept() => { 18 | let (stream, peer_addr) = match conn { 19 | Ok(conn) => conn, 20 | Err(e) => { 21 | eprintln!("accept error: {}", e); 22 | tokio::time::sleep(Duration::from_secs(1)).await; 23 | continue; 24 | } 25 | }; 26 | eprintln!("incomming connection accepted: {}", peer_addr); 27 | 28 | let stream = hyper_util::rt::TokioIo::new(Box::pin(stream)); 29 | 30 | let conn = server.serve_connection_with_upgrades(stream, hyper::service::service_fn(|_| async move { 31 | tokio::time::sleep(Duration::from_secs(5)).await; // emulate slow request 32 | let body = http_body_util::Full::::from("Hello World!".to_owned()); 33 | Ok::<_, Infallible>(http::Response::new(body)) 34 | })); 35 | 36 | let conn = graceful.watch(conn.into_owned()); 37 | 38 | tokio::spawn(async move { 39 | if let Err(err) = conn.await { 40 | eprintln!("connection error: {}", err); 41 | } 42 | eprintln!("connection dropped: {}", peer_addr); 43 | }); 44 | }, 45 | 46 | _ = ctrl_c.as_mut() => { 47 | drop(listener); 48 | eprintln!("Ctrl-C received, starting shutdown"); 49 | break; 50 | } 51 | } 52 | } 53 | 54 | tokio::select! { 55 | _ = graceful.shutdown() => { 56 | eprintln!("Gracefully shutdown!"); 57 | }, 58 | _ = tokio::time::sleep(Duration::from_secs(10)) => { 59 | eprintln!("Waited 10 seconds for graceful shutdown, aborting..."); 60 | } 61 | } 62 | 63 | Ok(()) 64 | } 65 | -------------------------------------------------------------------------------- /src/client/client.rs: -------------------------------------------------------------------------------- 1 | use hyper::{Request, Response}; 2 | use tower::{Service, MakeService}; 3 | 4 | use super::connect::Connect; 5 | use super::pool; 6 | 7 | pub struct Client { 8 | // Hi there. So, let's take a 0.14.x hyper::Client, and build up its layers 9 | // here. We don't need to fully expose the layers to start with, but that 10 | // is the end goal. 
11 | // 12 | // Client = MakeSvcAsService< 13 | // SetHost< 14 | // Http1RequestTarget< 15 | // DelayedRelease< 16 | // ConnectingPool 17 | // > 18 | // > 19 | // > 20 | // > 21 | make_svc: M, 22 | } 23 | 24 | // We might change this... :shrug: 25 | type PoolKey = hyper::Uri; 26 | 27 | struct ConnectingPool { 28 | connector: C, 29 | pool: P, 30 | } 31 | 32 | struct PoolableSvc(S); 33 | 34 | /// A marker to identify what version a pooled connection is. 35 | #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] 36 | #[allow(dead_code)] 37 | pub enum Ver { 38 | Auto, 39 | Http2, 40 | } 41 | 42 | // ===== impl Client ===== 43 | 44 | impl Client 45 | where 46 | M: MakeService< 47 | hyper::Uri, 48 | Request<()>, 49 | Response = Response<()>, 50 | Error = E, 51 | MakeError = E, 52 | >, 53 | //M: Service, 54 | //M::Response: Service, Response = Response>, 55 | { 56 | pub async fn request(&mut self, req: Request<()>) -> Result, E> { 57 | let mut svc = self.make_svc.make_service(req.uri().clone()).await?; 58 | svc.call(req).await 59 | } 60 | } 61 | 62 | impl Client 63 | where 64 | M: MakeService< 65 | hyper::Uri, 66 | Request<()>, 67 | Response = Response<()>, 68 | Error = E, 69 | MakeError = E, 70 | >, 71 | //M: Service, 72 | //M::Response: Service, Response = Response>, 73 | { 74 | 75 | } 76 | 77 | // ===== impl ConnectingPool ===== 78 | 79 | impl ConnectingPool 80 | where 81 | C: Connect, 82 | C::_Svc: Unpin + Send + 'static, 83 | { 84 | async fn connection_for(&self, target: PoolKey) -> Result, PoolKey>, ()> { 85 | todo!() 86 | } 87 | } 88 | 89 | impl pool::Poolable for PoolableSvc 90 | where 91 | S: Unpin + Send + 'static, 92 | { 93 | fn is_open(&self) -> bool { 94 | /* 95 | match self.tx { 96 | PoolTx::Http1(ref tx) => tx.is_ready(), 97 | #[cfg(feature = "http2")] 98 | PoolTx::Http2(ref tx) => tx.is_ready(), 99 | } 100 | */ 101 | true 102 | } 103 | 104 | fn reserve(self) -> pool::Reservation { 105 | /* 106 | match self.tx { 107 | PoolTx::Http1(tx) => 
Reservation::Unique(PoolClient { 108 | conn_info: self.conn_info, 109 | tx: PoolTx::Http1(tx), 110 | }), 111 | #[cfg(feature = "http2")] 112 | PoolTx::Http2(tx) => { 113 | let b = PoolClient { 114 | conn_info: self.conn_info.clone(), 115 | tx: PoolTx::Http2(tx.clone()), 116 | }; 117 | let a = PoolClient { 118 | conn_info: self.conn_info, 119 | tx: PoolTx::Http2(tx), 120 | }; 121 | Reservation::Shared(a, b) 122 | } 123 | } 124 | */ 125 | pool::Reservation::Unique(self) 126 | } 127 | 128 | fn can_share(&self) -> bool { 129 | false 130 | //self.is_http2() 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /src/client/legacy/connect/capture.rs: -------------------------------------------------------------------------------- 1 | use std::{ops::Deref, sync::Arc}; 2 | 3 | use http::Request; 4 | use tokio::sync::watch; 5 | 6 | use super::Connected; 7 | 8 | /// [`CaptureConnection`] allows callers to capture [`Connected`] information 9 | /// 10 | /// To capture a connection for a request, use [`capture_connection`]. 11 | #[derive(Debug, Clone)] 12 | pub struct CaptureConnection { 13 | rx: watch::Receiver>, 14 | } 15 | 16 | /// Capture the connection for a given request 17 | /// 18 | /// When making a request with Hyper, the underlying connection must implement the [`Connection`] trait. 19 | /// [`capture_connection`] allows a caller to capture the returned [`Connected`] structure as soon 20 | /// as the connection is established. 21 | /// 22 | /// [`Connection`]: crate::client::legacy::connect::Connection 23 | /// 24 | /// *Note*: If establishing a connection fails, [`CaptureConnection::connection_metadata`] will always return none. 25 | /// 26 | /// # Examples 27 | /// 28 | /// **Synchronous access**: 29 | /// The [`CaptureConnection::connection_metadata`] method allows callers to check if a connection has been 30 | /// established. 
This is ideal for situations where you are certain the connection has already 31 | /// been established (e.g. after the response future has already completed). 32 | /// ```rust 33 | /// use hyper_util::client::legacy::connect::capture_connection; 34 | /// let mut request = http::Request::builder() 35 | /// .uri("http://foo.com") 36 | /// .body(()) 37 | /// .unwrap(); 38 | /// 39 | /// let captured_connection = capture_connection(&mut request); 40 | /// // some time later after the request has been sent... 41 | /// let connection_info = captured_connection.connection_metadata(); 42 | /// println!("we are connected! {:?}", connection_info.as_ref()); 43 | /// ``` 44 | /// 45 | /// **Asynchronous access**: 46 | /// The [`CaptureConnection::wait_for_connection_metadata`] method returns a future resolves as soon as the 47 | /// connection is available. 48 | /// 49 | /// ```rust 50 | /// # #[cfg(feature = "tokio")] 51 | /// # async fn example() { 52 | /// use hyper_util::client::legacy::connect::capture_connection; 53 | /// use hyper_util::client::legacy::Client; 54 | /// use hyper_util::rt::TokioExecutor; 55 | /// use bytes::Bytes; 56 | /// use http_body_util::Empty; 57 | /// let mut request = http::Request::builder() 58 | /// .uri("http://foo.com") 59 | /// .body(Empty::::new()) 60 | /// .unwrap(); 61 | /// 62 | /// let mut captured = capture_connection(&mut request); 63 | /// tokio::task::spawn(async move { 64 | /// let connection_info = captured.wait_for_connection_metadata().await; 65 | /// println!("we are connected! 
{:?}", connection_info.as_ref()); 66 | /// }); 67 | /// 68 | /// let client = Client::builder(TokioExecutor::new()).build_http(); 69 | /// client.request(request).await.expect("request failed"); 70 | /// # } 71 | /// ``` 72 | pub fn capture_connection(request: &mut Request) -> CaptureConnection { 73 | let (tx, rx) = CaptureConnection::new(); 74 | request.extensions_mut().insert(tx); 75 | rx 76 | } 77 | 78 | /// TxSide for [`CaptureConnection`] 79 | /// 80 | /// This is inserted into `Extensions` to allow Hyper to back channel connection info 81 | #[derive(Clone)] 82 | pub(crate) struct CaptureConnectionExtension { 83 | tx: Arc>>, 84 | } 85 | 86 | impl CaptureConnectionExtension { 87 | pub(crate) fn set(&self, connected: &Connected) { 88 | self.tx.send_replace(Some(connected.clone())); 89 | } 90 | } 91 | 92 | impl CaptureConnection { 93 | /// Internal API to create the tx and rx halves of [`CaptureConnection`] 94 | pub(crate) fn new() -> (CaptureConnectionExtension, Self) { 95 | let (tx, rx) = watch::channel(None); 96 | ( 97 | CaptureConnectionExtension { tx: Arc::new(tx) }, 98 | CaptureConnection { rx }, 99 | ) 100 | } 101 | 102 | /// Retrieve the connection metadata, if available 103 | pub fn connection_metadata(&self) -> impl Deref> + '_ { 104 | self.rx.borrow() 105 | } 106 | 107 | /// Wait for the connection to be established 108 | /// 109 | /// If a connection was established, this will always return `Some(...)`. If the request never 110 | /// successfully connects (e.g. DNS resolution failure), this resolves to `None` once the sender half is dropped.
111 | pub async fn wait_for_connection_metadata( 112 | &mut self, 113 | ) -> impl Deref> + '_ { 114 | if self.rx.borrow().is_some() { 115 | return self.rx.borrow(); 116 | } 117 | let _ = self.rx.changed().await; 118 | self.rx.borrow() 119 | } 120 | } 121 | 122 | #[cfg(all(test, not(miri)))] 123 | mod test { 124 | use super::*; 125 | 126 | #[test] 127 | fn test_sync_capture_connection() { 128 | let (tx, rx) = CaptureConnection::new(); 129 | assert!( 130 | rx.connection_metadata().is_none(), 131 | "connection has not been set" 132 | ); 133 | tx.set(&Connected::new().proxy(true)); 134 | assert!(rx 135 | .connection_metadata() 136 | .as_ref() 137 | .expect("connected should be set") 138 | .is_proxied()); 139 | 140 | // ensure it can be called multiple times 141 | assert!(rx 142 | .connection_metadata() 143 | .as_ref() 144 | .expect("connected should be set") 145 | .is_proxied()); 146 | } 147 | 148 | #[tokio::test] 149 | async fn async_capture_connection() { 150 | let (tx, mut rx) = CaptureConnection::new(); 151 | assert!( 152 | rx.connection_metadata().is_none(), 153 | "connection has not been set" 154 | ); 155 | let test_task = tokio::spawn(async move { 156 | assert!(rx 157 | .wait_for_connection_metadata() 158 | .await 159 | .as_ref() 160 | .expect("connection should be set") 161 | .is_proxied()); 162 | // can be awaited multiple times 163 | assert!( 164 | rx.wait_for_connection_metadata().await.is_some(), 165 | "should be awaitable multiple times" 166 | ); 167 | 168 | assert!(rx.connection_metadata().is_some()); 169 | }); 170 | // can't be finished, we haven't set the connection yet 171 | assert!(!test_task.is_finished()); 172 | tx.set(&Connected::new().proxy(true)); 173 | 174 | assert!(test_task.await.is_ok()); 175 | } 176 | 177 | #[tokio::test] 178 | async fn capture_connection_sender_side_dropped() { 179 | let (tx, mut rx) = CaptureConnection::new(); 180 | assert!( 181 | rx.connection_metadata().is_none(), 182 | "connection has not been set" 183 | ); 184 | 
drop(tx); 185 | assert!(rx.wait_for_connection_metadata().await.is_none()); 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /src/client/legacy/connect/dns.rs: -------------------------------------------------------------------------------- 1 | //! DNS Resolution used by the `HttpConnector`. 2 | //! 3 | //! This module contains: 4 | //! 5 | //! - A [`GaiResolver`] that is the default resolver for the `HttpConnector`. 6 | //! - The `Name` type used as an argument to custom resolvers. 7 | //! 8 | //! # Resolvers are `Service`s 9 | //! 10 | //! A resolver is just a 11 | //! `Service>`. 12 | //! 13 | //! A simple resolver that ignores the name and always returns a specific 14 | //! address: 15 | //! 16 | //! ```rust,ignore 17 | //! use std::{convert::Infallible, iter, net::SocketAddr}; 18 | //! 19 | //! let resolver = tower::service_fn(|_name| async { 20 | //! Ok::<_, Infallible>(iter::once(SocketAddr::from(([127, 0, 0, 1], 8080)))) 21 | //! }); 22 | //! ``` 23 | use std::error::Error; 24 | use std::future::Future; 25 | use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; 26 | use std::pin::Pin; 27 | use std::str::FromStr; 28 | use std::task::{self, Poll}; 29 | use std::{fmt, io, vec}; 30 | 31 | use tokio::task::JoinHandle; 32 | use tower_service::Service; 33 | 34 | pub(super) use self::sealed::Resolve; 35 | 36 | /// A domain name to resolve into IP addresses. 37 | #[derive(Clone, Hash, Eq, PartialEq)] 38 | pub struct Name { 39 | host: Box, 40 | } 41 | 42 | /// A resolver using blocking `getaddrinfo` calls in a threadpool. 43 | #[derive(Clone)] 44 | pub struct GaiResolver { 45 | _priv: (), 46 | } 47 | 48 | /// An iterator of IP addresses returned from `getaddrinfo`. 49 | pub struct GaiAddrs { 50 | inner: SocketAddrs, 51 | } 52 | 53 | /// A future to resolve a name returned by `GaiResolver`. 
54 | pub struct GaiFuture { 55 | inner: JoinHandle>, 56 | } 57 | 58 | impl Name { 59 | pub(super) fn new(host: Box) -> Name { 60 | Name { host } 61 | } 62 | 63 | /// View the hostname as a string slice. 64 | pub fn as_str(&self) -> &str { 65 | &self.host 66 | } 67 | } 68 | 69 | impl fmt::Debug for Name { 70 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 71 | fmt::Debug::fmt(&self.host, f) 72 | } 73 | } 74 | 75 | impl fmt::Display for Name { 76 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 77 | fmt::Display::fmt(&self.host, f) 78 | } 79 | } 80 | 81 | impl FromStr for Name { 82 | type Err = InvalidNameError; 83 | 84 | fn from_str(host: &str) -> Result { 85 | // Possibly add validation later 86 | Ok(Name::new(host.into())) 87 | } 88 | } 89 | 90 | /// Error indicating a given string was not a valid domain name. 91 | #[derive(Debug)] 92 | pub struct InvalidNameError(()); 93 | 94 | impl fmt::Display for InvalidNameError { 95 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 96 | f.write_str("Not a valid domain name") 97 | } 98 | } 99 | 100 | impl Error for InvalidNameError {} 101 | 102 | impl GaiResolver { 103 | /// Construct a new `GaiResolver`. 
104 | pub fn new() -> Self { 105 | GaiResolver { _priv: () } 106 | } 107 | } 108 | 109 | impl Service for GaiResolver { 110 | type Response = GaiAddrs; 111 | type Error = io::Error; 112 | type Future = GaiFuture; 113 | 114 | fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll> { 115 | Poll::Ready(Ok(())) 116 | } 117 | 118 | fn call(&mut self, name: Name) -> Self::Future { 119 | let blocking = tokio::task::spawn_blocking(move || { 120 | (&*name.host, 0) 121 | .to_socket_addrs() 122 | .map(|i| SocketAddrs { iter: i }) 123 | }); 124 | 125 | GaiFuture { inner: blocking } 126 | } 127 | } 128 | 129 | impl fmt::Debug for GaiResolver { 130 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 131 | f.pad("GaiResolver") 132 | } 133 | } 134 | 135 | impl Future for GaiFuture { 136 | type Output = Result; 137 | 138 | fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { 139 | Pin::new(&mut self.inner).poll(cx).map(|res| match res { 140 | Ok(Ok(addrs)) => Ok(GaiAddrs { inner: addrs }), 141 | Ok(Err(err)) => Err(err), 142 | Err(join_err) => { 143 | if join_err.is_cancelled() { 144 | Err(io::Error::new(io::ErrorKind::Interrupted, join_err)) 145 | } else { 146 | panic!("gai background task failed: {join_err:?}") 147 | } 148 | } 149 | }) 150 | } 151 | } 152 | 153 | impl fmt::Debug for GaiFuture { 154 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 155 | f.pad("GaiFuture") 156 | } 157 | } 158 | 159 | impl Drop for GaiFuture { 160 | fn drop(&mut self) { 161 | self.inner.abort(); 162 | } 163 | } 164 | 165 | impl Iterator for GaiAddrs { 166 | type Item = SocketAddr; 167 | 168 | fn next(&mut self) -> Option { 169 | self.inner.next() 170 | } 171 | } 172 | 173 | impl fmt::Debug for GaiAddrs { 174 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 175 | f.pad("GaiAddrs") 176 | } 177 | } 178 | 179 | pub(super) struct SocketAddrs { 180 | iter: vec::IntoIter, 181 | } 182 | 183 | impl SocketAddrs { 184 | pub(super) fn new(addrs: Vec) -> 
Self { 185 | SocketAddrs { 186 | iter: addrs.into_iter(), 187 | } 188 | } 189 | 190 | pub(super) fn try_parse(host: &str, port: u16) -> Option { 191 | if let Ok(addr) = host.parse::() { 192 | let addr = SocketAddrV4::new(addr, port); 193 | return Some(SocketAddrs { 194 | iter: vec![SocketAddr::V4(addr)].into_iter(), 195 | }); 196 | } 197 | if let Ok(addr) = host.parse::() { 198 | let addr = SocketAddrV6::new(addr, port, 0, 0); 199 | return Some(SocketAddrs { 200 | iter: vec![SocketAddr::V6(addr)].into_iter(), 201 | }); 202 | } 203 | None 204 | } 205 | 206 | #[inline] 207 | fn filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> SocketAddrs { 208 | SocketAddrs::new(self.iter.filter(predicate).collect()) 209 | } 210 | 211 | pub(super) fn split_by_preference( 212 | self, 213 | local_addr_ipv4: Option, 214 | local_addr_ipv6: Option, 215 | ) -> (SocketAddrs, SocketAddrs) { 216 | match (local_addr_ipv4, local_addr_ipv6) { 217 | // Filter out based on what the local addr can use 218 | (Some(_), None) => (self.filter(SocketAddr::is_ipv4), SocketAddrs::new(vec![])), 219 | (None, Some(_)) => (self.filter(SocketAddr::is_ipv6), SocketAddrs::new(vec![])), 220 | _ => { 221 | // Happy Eyeballs says we always give a preference to v6 if available 222 | let (preferred, fallback) = self.iter.partition::, _>(SocketAddr::is_ipv6); 223 | 224 | (SocketAddrs::new(preferred), SocketAddrs::new(fallback)) 225 | } 226 | } 227 | } 228 | 229 | pub(super) fn is_empty(&self) -> bool { 230 | self.iter.as_slice().is_empty() 231 | } 232 | 233 | pub(super) fn len(&self) -> usize { 234 | self.iter.as_slice().len() 235 | } 236 | } 237 | 238 | impl Iterator for SocketAddrs { 239 | type Item = SocketAddr; 240 | #[inline] 241 | fn next(&mut self) -> Option { 242 | self.iter.next() 243 | } 244 | } 245 | 246 | mod sealed { 247 | use std::future::Future; 248 | use std::task::{self, Poll}; 249 | 250 | use super::{Name, SocketAddr}; 251 | use tower_service::Service; 252 | 253 | // "Trait alias" for 
`Service` 254 | pub trait Resolve { 255 | type Addrs: Iterator; 256 | type Error: Into>; 257 | type Future: Future>; 258 | 259 | fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll>; 260 | fn resolve(&mut self, name: Name) -> Self::Future; 261 | } 262 | 263 | impl Resolve for S 264 | where 265 | S: Service, 266 | S::Response: Iterator, 267 | S::Error: Into>, 268 | { 269 | type Addrs = S::Response; 270 | type Error = S::Error; 271 | type Future = S::Future; 272 | 273 | fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { 274 | Service::poll_ready(self, cx) 275 | } 276 | 277 | fn resolve(&mut self, name: Name) -> Self::Future { 278 | Service::call(self, name) 279 | } 280 | } 281 | } 282 | 283 | pub(super) async fn resolve(resolver: &mut R, name: Name) -> Result 284 | where 285 | R: Resolve, 286 | { 287 | crate::common::future::poll_fn(|cx| resolver.poll_ready(cx)).await?; 288 | resolver.resolve(name).await 289 | } 290 | 291 | #[cfg(test)] 292 | mod tests { 293 | use super::*; 294 | use std::net::{Ipv4Addr, Ipv6Addr}; 295 | 296 | #[test] 297 | fn test_ip_addrs_split_by_preference() { 298 | let ip_v4 = Ipv4Addr::new(127, 0, 0, 1); 299 | let ip_v6 = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1); 300 | let v4_addr = (ip_v4, 80).into(); 301 | let v6_addr = (ip_v6, 80).into(); 302 | 303 | // Even if ipv4 started first, prefer ipv6 304 | let (mut preferred, mut fallback) = SocketAddrs { 305 | iter: vec![v4_addr, v6_addr].into_iter(), 306 | } 307 | .split_by_preference(None, None); 308 | assert!(preferred.next().unwrap().is_ipv6()); 309 | assert!(fallback.next().unwrap().is_ipv4()); 310 | 311 | let (mut preferred, mut fallback) = SocketAddrs { 312 | iter: vec![v6_addr, v4_addr].into_iter(), 313 | } 314 | .split_by_preference(None, None); 315 | assert!(preferred.next().unwrap().is_ipv6()); 316 | assert!(fallback.next().unwrap().is_ipv4()); 317 | 318 | let (mut preferred, mut fallback) = SocketAddrs { 319 | iter: vec![v4_addr, v6_addr].into_iter(), 320 | } 321 | 
.split_by_preference(Some(ip_v4), Some(ip_v6)); 322 | assert!(preferred.next().unwrap().is_ipv6()); 323 | assert!(fallback.next().unwrap().is_ipv4()); 324 | 325 | let (mut preferred, mut fallback) = SocketAddrs { 326 | iter: vec![v6_addr, v4_addr].into_iter(), 327 | } 328 | .split_by_preference(Some(ip_v4), Some(ip_v6)); 329 | assert!(preferred.next().unwrap().is_ipv6()); 330 | assert!(fallback.next().unwrap().is_ipv4()); 331 | 332 | let (mut preferred, fallback) = SocketAddrs { 333 | iter: vec![v4_addr, v6_addr].into_iter(), 334 | } 335 | .split_by_preference(Some(ip_v4), None); 336 | assert!(preferred.next().unwrap().is_ipv4()); 337 | assert!(fallback.is_empty()); 338 | 339 | let (mut preferred, fallback) = SocketAddrs { 340 | iter: vec![v4_addr, v6_addr].into_iter(), 341 | } 342 | .split_by_preference(None, Some(ip_v6)); 343 | assert!(preferred.next().unwrap().is_ipv6()); 344 | assert!(fallback.is_empty()); 345 | } 346 | 347 | #[test] 348 | fn test_name_from_str() { 349 | const DOMAIN: &str = "test.example.com"; 350 | let name = Name::from_str(DOMAIN).expect("Should be a valid domain"); 351 | assert_eq!(name.as_str(), DOMAIN); 352 | assert_eq!(name.to_string(), DOMAIN); 353 | } 354 | } 355 | -------------------------------------------------------------------------------- /src/client/legacy/connect/mod.rs: -------------------------------------------------------------------------------- 1 | //! Connectors used by the `Client`. 2 | //! 3 | //! This module contains: 4 | //! 5 | //! - A default [`HttpConnector`][] that does DNS resolution and establishes 6 | //! connections over TCP. 7 | //! - Types to build custom connectors. 8 | //! 9 | //! # Connectors 10 | //! 11 | //! A "connector" is a [`Service`][] that takes a [`Uri`][] destination, and 12 | //! its `Response` is some type implementing [`Read`][], [`Write`][], 13 | //! and [`Connection`][]. 14 | //! 15 | //! ## Custom Connectors 16 | //! 17 | //! 
A simple connector that ignores the `Uri` destination and always returns 18 | //! a TCP connection to the same address could be written like this: 19 | //! 20 | //! ```rust,ignore 21 | //! let connector = tower::service_fn(|_dst| async { 22 | //! tokio::net::TcpStream::connect("127.0.0.1:1337") 23 | //! }) 24 | //! ``` 25 | //! 26 | //! Or, fully written out: 27 | //! 28 | //! ``` 29 | //! use std::{future::Future, net::SocketAddr, pin::Pin, task::{self, Poll}}; 30 | //! use http::Uri; 31 | //! use tokio::net::TcpStream; 32 | //! use tower_service::Service; 33 | //! 34 | //! #[derive(Clone)] 35 | //! struct LocalConnector; 36 | //! 37 | //! impl Service for LocalConnector { 38 | //! type Response = TcpStream; 39 | //! type Error = std::io::Error; 40 | //! // We can't "name" an `async` generated future. 41 | //! type Future = Pin> + Send 43 | //! >>; 44 | //! 45 | //! fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll> { 46 | //! // This connector is always ready, but others might not be. 47 | //! Poll::Ready(Ok(())) 48 | //! } 49 | //! 50 | //! fn call(&mut self, _: Uri) -> Self::Future { 51 | //! Box::pin(TcpStream::connect(SocketAddr::from(([127, 0, 0, 1], 1337)))) 52 | //! } 53 | //! } 54 | //! ``` 55 | //! 56 | //! It's worth noting that for `TcpStream`s, the [`HttpConnector`][] is a 57 | //! better starting place to extend from. 58 | //! 59 | //! [`HttpConnector`]: HttpConnector 60 | //! [`Service`]: tower_service::Service 61 | //! [`Uri`]: ::http::Uri 62 | //! [`Read`]: hyper::rt::Read 63 | //! [`Write`]: hyper::rt::Write 64 | //! 
[`Connection`]: Connection 65 | use std::{ 66 | fmt::{self, Formatter}, 67 | sync::{ 68 | atomic::{AtomicBool, Ordering}, 69 | Arc, 70 | }, 71 | }; 72 | 73 | use ::http::Extensions; 74 | 75 | #[cfg(feature = "tokio")] 76 | pub use self::http::{HttpConnector, HttpInfo}; 77 | 78 | #[cfg(feature = "tokio")] 79 | pub mod dns; 80 | #[cfg(feature = "tokio")] 81 | mod http; 82 | 83 | pub mod proxy; 84 | 85 | pub(crate) mod capture; 86 | pub use capture::{capture_connection, CaptureConnection}; 87 | 88 | pub use self::sealed::Connect; 89 | 90 | /// Describes a type returned by a connector. 91 | pub trait Connection { 92 | /// Return metadata describing the connection. 93 | fn connected(&self) -> Connected; 94 | } 95 | 96 | /// Extra information about the connected transport. 97 | /// 98 | /// This can be used to inform recipients about things like if ALPN 99 | /// was used, or if connected to an HTTP proxy. 100 | #[derive(Debug)] 101 | pub struct Connected { 102 | pub(super) alpn: Alpn, 103 | pub(super) is_proxied: bool, 104 | pub(super) extra: Option, 105 | pub(super) poisoned: PoisonPill, 106 | } 107 | 108 | #[derive(Clone)] 109 | pub(crate) struct PoisonPill { 110 | poisoned: Arc, 111 | } 112 | 113 | impl fmt::Debug for PoisonPill { 114 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 115 | // print the address of the pill—this makes debugging issues much easier 116 | write!( 117 | f, 118 | "PoisonPill@{:p} {{ poisoned: {} }}", 119 | self.poisoned, 120 | self.poisoned.load(Ordering::Relaxed) 121 | ) 122 | } 123 | } 124 | 125 | impl PoisonPill { 126 | pub(crate) fn healthy() -> Self { 127 | Self { 128 | poisoned: Arc::new(AtomicBool::new(false)), 129 | } 130 | } 131 | pub(crate) fn poison(&self) { 132 | self.poisoned.store(true, Ordering::Relaxed) 133 | } 134 | 135 | pub(crate) fn poisoned(&self) -> bool { 136 | self.poisoned.load(Ordering::Relaxed) 137 | } 138 | } 139 | 140 | pub(super) struct Extra(Box); 141 | 142 | #[derive(Clone, Copy, Debug, PartialEq)] 143 | 
pub(super) enum Alpn { 144 | H2, 145 | None, 146 | } 147 | 148 | impl Connected { 149 | /// Create new `Connected` type with empty metadata. 150 | pub fn new() -> Connected { 151 | Connected { 152 | alpn: Alpn::None, 153 | is_proxied: false, 154 | extra: None, 155 | poisoned: PoisonPill::healthy(), 156 | } 157 | } 158 | 159 | /// Set whether the connected transport is to an HTTP proxy. 160 | /// 161 | /// This setting will affect if HTTP/1 requests written on the transport 162 | /// will have the request-target in absolute-form or origin-form: 163 | /// 164 | /// - When `proxy(false)`: 165 | /// 166 | /// ```http 167 | /// GET /guide HTTP/1.1 168 | /// ``` 169 | /// 170 | /// - When `proxy(true)`: 171 | /// 172 | /// ```http 173 | /// GET http://hyper.rs/guide HTTP/1.1 174 | /// ``` 175 | /// 176 | /// Default is `false`. 177 | pub fn proxy(mut self, is_proxied: bool) -> Connected { 178 | self.is_proxied = is_proxied; 179 | self 180 | } 181 | 182 | /// Determines if the connected transport is to an HTTP proxy. 183 | pub fn is_proxied(&self) -> bool { 184 | self.is_proxied 185 | } 186 | 187 | /// Set extra connection information to be set in the extensions of every `Response`. 188 | pub fn extra(mut self, extra: T) -> Connected { 189 | if let Some(prev) = self.extra { 190 | self.extra = Some(Extra(Box::new(ExtraChain(prev.0, extra)))); 191 | } else { 192 | self.extra = Some(Extra(Box::new(ExtraEnvelope(extra)))); 193 | } 194 | self 195 | } 196 | 197 | /// Copies the extra connection information into an `Extensions` map. 198 | pub fn get_extras(&self, extensions: &mut Extensions) { 199 | if let Some(extra) = &self.extra { 200 | extra.set(extensions); 201 | } 202 | } 203 | 204 | /// Set that the connected transport negotiated HTTP/2 as its next protocol. 205 | pub fn negotiated_h2(mut self) -> Connected { 206 | self.alpn = Alpn::H2; 207 | self 208 | } 209 | 210 | /// Determines if the connected transport negotiated HTTP/2 as its next protocol. 
211 | pub fn is_negotiated_h2(&self) -> bool { 212 | self.alpn == Alpn::H2 213 | } 214 | 215 | /// Poison this connection 216 | /// 217 | /// A poisoned connection will not be reused for subsequent requests by the pool 218 | pub fn poison(&self) { 219 | self.poisoned.poison(); 220 | tracing::debug!( 221 | poison_pill = ?self.poisoned, "connection was poisoned. this connection will not be reused for subsequent requests" 222 | ); 223 | } 224 | 225 | // Don't public expose that `Connected` is `Clone`, unsure if we want to 226 | // keep that contract... 227 | pub(super) fn clone(&self) -> Connected { 228 | Connected { 229 | alpn: self.alpn, 230 | is_proxied: self.is_proxied, 231 | extra: self.extra.clone(), 232 | poisoned: self.poisoned.clone(), 233 | } 234 | } 235 | } 236 | 237 | // ===== impl Extra ===== 238 | 239 | impl Extra { 240 | pub(super) fn set(&self, res: &mut Extensions) { 241 | self.0.set(res); 242 | } 243 | } 244 | 245 | impl Clone for Extra { 246 | fn clone(&self) -> Extra { 247 | Extra(self.0.clone_box()) 248 | } 249 | } 250 | 251 | impl fmt::Debug for Extra { 252 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 253 | f.debug_struct("Extra").finish() 254 | } 255 | } 256 | 257 | trait ExtraInner: Send + Sync { 258 | fn clone_box(&self) -> Box; 259 | fn set(&self, res: &mut Extensions); 260 | } 261 | 262 | // This indirection allows the `Connected` to have a type-erased "extra" value, 263 | // while that type still knows its inner extra type. This allows the correct 264 | // TypeId to be used when inserting into `res.extensions_mut()`. 
265 | #[derive(Clone)] 266 | struct ExtraEnvelope(T); 267 | 268 | impl ExtraInner for ExtraEnvelope 269 | where 270 | T: Clone + Send + Sync + 'static, 271 | { 272 | fn clone_box(&self) -> Box { 273 | Box::new(self.clone()) 274 | } 275 | 276 | fn set(&self, res: &mut Extensions) { 277 | res.insert(self.0.clone()); 278 | } 279 | } 280 | 281 | struct ExtraChain(Box, T); 282 | 283 | impl Clone for ExtraChain { 284 | fn clone(&self) -> Self { 285 | ExtraChain(self.0.clone_box(), self.1.clone()) 286 | } 287 | } 288 | 289 | impl ExtraInner for ExtraChain 290 | where 291 | T: Clone + Send + Sync + 'static, 292 | { 293 | fn clone_box(&self) -> Box { 294 | Box::new(self.clone()) 295 | } 296 | 297 | fn set(&self, res: &mut Extensions) { 298 | self.0.set(res); 299 | res.insert(self.1.clone()); 300 | } 301 | } 302 | 303 | pub(super) mod sealed { 304 | use std::error::Error as StdError; 305 | use std::future::Future; 306 | 307 | use ::http::Uri; 308 | use hyper::rt::{Read, Write}; 309 | 310 | use super::Connection; 311 | 312 | /// Connect to a destination, returning an IO transport. 313 | /// 314 | /// A connector receives a [`Uri`](::http::Uri) and returns a `Future` of the 315 | /// ready connection. 316 | /// 317 | /// # Trait Alias 318 | /// 319 | /// This is really just an *alias* for the `tower::Service` trait, with 320 | /// additional bounds set for convenience *inside* hyper. You don't actually 321 | /// implement this trait, but `tower::Service` instead. 322 | // The `Sized` bound is to prevent creating `dyn Connect`, since they cannot 323 | // fit the `Connect` bounds because of the blanket impl for `Service`. 
324 | pub trait Connect: Sealed + Sized { 325 | #[doc(hidden)] 326 | type _Svc: ConnectSvc; 327 | #[doc(hidden)] 328 | fn connect(self, internal_only: Internal, dst: Uri) -> ::Future; 329 | } 330 | 331 | pub trait ConnectSvc { 332 | type Connection: Read + Write + Connection + Unpin + Send + 'static; 333 | type Error: Into>; 334 | type Future: Future> + Unpin + Send + 'static; 335 | 336 | fn connect(self, internal_only: Internal, dst: Uri) -> Self::Future; 337 | } 338 | 339 | impl Connect for S 340 | where 341 | S: tower_service::Service + Send + 'static, 342 | S::Error: Into>, 343 | S::Future: Unpin + Send, 344 | T: Read + Write + Connection + Unpin + Send + 'static, 345 | { 346 | type _Svc = S; 347 | 348 | fn connect(self, _: Internal, dst: Uri) -> crate::service::Oneshot { 349 | crate::service::Oneshot::new(self, dst) 350 | } 351 | } 352 | 353 | impl ConnectSvc for S 354 | where 355 | S: tower_service::Service + Send + 'static, 356 | S::Error: Into>, 357 | S::Future: Unpin + Send, 358 | T: Read + Write + Connection + Unpin + Send + 'static, 359 | { 360 | type Connection = T; 361 | type Error = S::Error; 362 | type Future = crate::service::Oneshot; 363 | 364 | fn connect(self, _: Internal, dst: Uri) -> Self::Future { 365 | crate::service::Oneshot::new(self, dst) 366 | } 367 | } 368 | 369 | impl Sealed for S 370 | where 371 | S: tower_service::Service + Send, 372 | S::Error: Into>, 373 | S::Future: Unpin + Send, 374 | T: Read + Write + Connection + Unpin + Send + 'static, 375 | { 376 | } 377 | 378 | pub trait Sealed {} 379 | #[allow(missing_debug_implementations)] 380 | pub struct Internal; 381 | } 382 | 383 | #[cfg(test)] 384 | mod tests { 385 | use super::Connected; 386 | 387 | #[derive(Clone, Debug, PartialEq)] 388 | struct Ex1(usize); 389 | 390 | #[derive(Clone, Debug, PartialEq)] 391 | struct Ex2(&'static str); 392 | 393 | #[derive(Clone, Debug, PartialEq)] 394 | struct Ex3(&'static str); 395 | 396 | #[test] 397 | fn test_connected_extra() { 398 | let c1 = 
Connected::new().extra(Ex1(41)); 399 | 400 | let mut ex = ::http::Extensions::new(); 401 | 402 | assert_eq!(ex.get::(), None); 403 | 404 | c1.extra.as_ref().expect("c1 extra").set(&mut ex); 405 | 406 | assert_eq!(ex.get::(), Some(&Ex1(41))); 407 | } 408 | 409 | #[test] 410 | fn test_connected_extra_chain() { 411 | // If a user composes connectors and at each stage, there's "extra" 412 | // info to attach, it shouldn't override the previous extras. 413 | 414 | let c1 = Connected::new() 415 | .extra(Ex1(45)) 416 | .extra(Ex2("zoom")) 417 | .extra(Ex3("pew pew")); 418 | 419 | let mut ex1 = ::http::Extensions::new(); 420 | 421 | assert_eq!(ex1.get::(), None); 422 | assert_eq!(ex1.get::(), None); 423 | assert_eq!(ex1.get::(), None); 424 | 425 | c1.extra.as_ref().expect("c1 extra").set(&mut ex1); 426 | 427 | assert_eq!(ex1.get::(), Some(&Ex1(45))); 428 | assert_eq!(ex1.get::(), Some(&Ex2("zoom"))); 429 | assert_eq!(ex1.get::(), Some(&Ex3("pew pew"))); 430 | 431 | // Just like extensions, inserting the same type overrides previous type. 432 | let c2 = Connected::new() 433 | .extra(Ex1(33)) 434 | .extra(Ex2("hiccup")) 435 | .extra(Ex1(99)); 436 | 437 | let mut ex2 = ::http::Extensions::new(); 438 | 439 | c2.extra.as_ref().expect("c2 extra").set(&mut ex2); 440 | 441 | assert_eq!(ex2.get::(), Some(&Ex1(99))); 442 | assert_eq!(ex2.get::(), Some(&Ex2("hiccup"))); 443 | } 444 | } 445 | -------------------------------------------------------------------------------- /src/client/legacy/connect/proxy/mod.rs: -------------------------------------------------------------------------------- 1 | //! 
Proxy helpers 2 | mod socks; 3 | mod tunnel; 4 | 5 | pub use self::socks::{SocksV4, SocksV5}; 6 | pub use self::tunnel::Tunnel; 7 | -------------------------------------------------------------------------------- /src/client/legacy/connect/proxy/socks/mod.rs: -------------------------------------------------------------------------------- 1 | mod v5; 2 | pub use v5::{SocksV5, SocksV5Error}; 3 | 4 | mod v4; 5 | pub use v4::{SocksV4, SocksV4Error}; 6 | 7 | use bytes::BytesMut; 8 | 9 | use hyper::rt::Read; 10 | 11 | #[derive(Debug)] 12 | pub enum SocksError { 13 | Inner(C), 14 | Io(std::io::Error), 15 | 16 | DnsFailure, 17 | MissingHost, 18 | MissingPort, 19 | 20 | V4(SocksV4Error), 21 | V5(SocksV5Error), 22 | 23 | Parsing(ParsingError), 24 | Serialize(SerializeError), 25 | } 26 | 27 | #[derive(Debug)] 28 | pub enum ParsingError { 29 | Incomplete, 30 | WouldOverflow, 31 | Other, 32 | } 33 | 34 | #[derive(Debug)] 35 | pub enum SerializeError { 36 | WouldOverflow, 37 | } 38 | 39 | async fn read_message(mut conn: &mut T, buf: &mut BytesMut) -> Result> 40 | where 41 | T: Read + Unpin, 42 | M: for<'a> TryFrom<&'a mut BytesMut, Error = ParsingError>, 43 | { 44 | let mut tmp = [0; 513]; 45 | 46 | loop { 47 | let n = crate::rt::read(&mut conn, &mut tmp).await?; 48 | buf.extend_from_slice(&tmp[..n]); 49 | 50 | match M::try_from(buf) { 51 | Err(ParsingError::Incomplete) => { 52 | if n == 0 { 53 | if buf.spare_capacity_mut().is_empty() { 54 | return Err(SocksError::Parsing(ParsingError::WouldOverflow)); 55 | } else { 56 | return Err(std::io::Error::new( 57 | std::io::ErrorKind::UnexpectedEof, 58 | "unexpected eof", 59 | ) 60 | .into()); 61 | } 62 | } 63 | } 64 | Err(err) => return Err(err.into()), 65 | Ok(res) => return Ok(res), 66 | } 67 | } 68 | } 69 | 70 | impl std::fmt::Display for SocksError { 71 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 72 | f.write_str("SOCKS error: ")?; 73 | 74 | match self { 75 | Self::Inner(_) => f.write_str("failed to 
create underlying connection"), 76 | Self::Io(_) => f.write_str("io error during SOCKS handshake"), 77 | 78 | Self::DnsFailure => f.write_str("could not resolve to acceptable address type"), 79 | Self::MissingHost => f.write_str("missing destination host"), 80 | Self::MissingPort => f.write_str("missing destination port"), 81 | 82 | Self::Parsing(_) => f.write_str("failed parsing server response"), 83 | Self::Serialize(_) => f.write_str("failed serialize request"), 84 | 85 | Self::V4(e) => e.fmt(f), 86 | Self::V5(e) => e.fmt(f), 87 | } 88 | } 89 | } 90 | 91 | impl std::error::Error for SocksError {} 92 | 93 | impl From for SocksError { 94 | fn from(err: std::io::Error) -> Self { 95 | Self::Io(err) 96 | } 97 | } 98 | 99 | impl From for SocksError { 100 | fn from(err: ParsingError) -> Self { 101 | Self::Parsing(err) 102 | } 103 | } 104 | 105 | impl From for SocksError { 106 | fn from(err: SerializeError) -> Self { 107 | Self::Serialize(err) 108 | } 109 | } 110 | 111 | impl From for SocksError { 112 | fn from(err: SocksV4Error) -> Self { 113 | Self::V4(err) 114 | } 115 | } 116 | 117 | impl From for SocksError { 118 | fn from(err: SocksV5Error) -> Self { 119 | Self::V5(err) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /src/client/legacy/connect/proxy/socks/v4/errors.rs: -------------------------------------------------------------------------------- 1 | use super::Status; 2 | 3 | #[derive(Debug)] 4 | pub enum SocksV4Error { 5 | IpV6, 6 | Command(Status), 7 | } 8 | 9 | impl From for SocksV4Error { 10 | fn from(err: Status) -> Self { 11 | Self::Command(err) 12 | } 13 | } 14 | 15 | impl std::fmt::Display for SocksV4Error { 16 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 17 | match self { 18 | Self::IpV6 => f.write_str("IPV6 is not supported"), 19 | Self::Command(status) => status.fmt(f), 20 | } 21 | } 22 | } 23 | 
-------------------------------------------------------------------------------- /src/client/legacy/connect/proxy/socks/v4/messages.rs: -------------------------------------------------------------------------------- 1 | use super::super::{ParsingError, SerializeError}; 2 | 3 | use bytes::{Buf, BufMut, BytesMut}; 4 | use std::net::SocketAddrV4; 5 | 6 | /// +-----+-----+----+----+----+----+----+----+-------------+------+------------+------+ 7 | /// | VN | CD | DSTPORT | DSTIP | USERID | NULL | DOMAIN | NULL | 8 | /// +-----+-----+----+----+----+----+----+----+-------------+------+------------+------+ 9 | /// | 1 | 1 | 2 | 4 | Variable | 1 | Variable | 1 | 10 | /// +-----+-----+----+----+----+----+----+----+-------------+------+------------+------+ 11 | /// ^^^^^^^^^^^^^^^^^^^^^ 12 | /// optional: only do IP is 0.0.0.X 13 | #[derive(Debug)] 14 | pub struct Request<'a>(pub &'a Address); 15 | 16 | /// +-----+-----+----+----+----+----+----+----+ 17 | /// | VN | CD | DSTPORT | DSTIP | 18 | /// +-----+-----+----+----+----+----+----+----+ 19 | /// | 1 | 1 | 2 | 4 | 20 | /// +-----+-----+----+----+----+----+----+----+ 21 | /// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 22 | /// ignore: only for SOCKSv4 BIND 23 | #[derive(Debug)] 24 | pub struct Response(pub Status); 25 | 26 | #[derive(Debug)] 27 | pub enum Address { 28 | Socket(SocketAddrV4), 29 | Domain(String, u16), 30 | } 31 | 32 | #[derive(Debug, PartialEq)] 33 | pub enum Status { 34 | Success = 90, 35 | Failed = 91, 36 | IdentFailure = 92, 37 | IdentMismatch = 93, 38 | } 39 | 40 | impl Request<'_> { 41 | pub fn write_to_buf(&self, mut buf: B) -> Result { 42 | match self.0 { 43 | Address::Socket(socket) => { 44 | if buf.remaining_mut() < 10 { 45 | return Err(SerializeError::WouldOverflow); 46 | } 47 | 48 | buf.put_u8(0x04); // Version 49 | buf.put_u8(0x01); // CONNECT 50 | 51 | buf.put_u16(socket.port()); // Port 52 | buf.put_slice(&socket.ip().octets()); // IP 53 | 54 | buf.put_u8(0x00); // USERID 55 | buf.put_u8(0x00); // NULL 
56 | 57 | Ok(10) 58 | } 59 | 60 | Address::Domain(domain, port) => { 61 | if buf.remaining_mut() < 10 + domain.len() + 1 { 62 | return Err(SerializeError::WouldOverflow); 63 | } 64 | 65 | buf.put_u8(0x04); // Version 66 | buf.put_u8(0x01); // CONNECT 67 | 68 | buf.put_u16(*port); // IP 69 | buf.put_slice(&[0x00, 0x00, 0x00, 0xFF]); // Invalid IP 70 | 71 | buf.put_u8(0x00); // USERID 72 | buf.put_u8(0x00); // NULL 73 | 74 | buf.put_slice(domain.as_bytes()); // Domain 75 | buf.put_u8(0x00); // NULL 76 | 77 | Ok(10 + domain.len() + 1) 78 | } 79 | } 80 | } 81 | } 82 | 83 | impl TryFrom<&mut BytesMut> for Response { 84 | type Error = ParsingError; 85 | 86 | fn try_from(buf: &mut BytesMut) -> Result { 87 | if buf.remaining() < 8 { 88 | return Err(ParsingError::Incomplete); 89 | } 90 | 91 | if buf.get_u8() != 0x00 { 92 | return Err(ParsingError::Other); 93 | } 94 | 95 | let status = buf.get_u8().try_into()?; 96 | let _addr = { 97 | let port = buf.get_u16(); 98 | let mut ip = [0; 4]; 99 | buf.copy_to_slice(&mut ip); 100 | 101 | SocketAddrV4::new(ip.into(), port) 102 | }; 103 | 104 | Ok(Self(status)) 105 | } 106 | } 107 | 108 | impl TryFrom for Status { 109 | type Error = ParsingError; 110 | 111 | fn try_from(byte: u8) -> Result { 112 | Ok(match byte { 113 | 90 => Self::Success, 114 | 91 => Self::Failed, 115 | 92 => Self::IdentFailure, 116 | 93 => Self::IdentMismatch, 117 | _ => return Err(ParsingError::Other), 118 | }) 119 | } 120 | } 121 | 122 | impl std::fmt::Display for Status { 123 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 124 | f.write_str(match self { 125 | Self::Success => "success", 126 | Self::Failed => "server failed to execute command", 127 | Self::IdentFailure => "server ident service failed", 128 | Self::IdentMismatch => "server ident service did not recognise client identifier", 129 | }) 130 | } 131 | } 132 | -------------------------------------------------------------------------------- 
/src/client/legacy/connect/proxy/socks/v4/mod.rs: -------------------------------------------------------------------------------- 1 | mod errors; 2 | pub use errors::*; 3 | 4 | mod messages; 5 | use messages::*; 6 | 7 | use std::future::Future; 8 | use std::pin::Pin; 9 | use std::task::{Context, Poll}; 10 | 11 | use std::net::{IpAddr, SocketAddr, SocketAddrV4, ToSocketAddrs}; 12 | 13 | use http::Uri; 14 | use hyper::rt::{Read, Write}; 15 | use tower_service::Service; 16 | 17 | use bytes::BytesMut; 18 | 19 | use pin_project_lite::pin_project; 20 | 21 | /// Tunnel Proxy via SOCKSv4 22 | /// 23 | /// This is a connector that can be used by the `legacy::Client`. It wraps 24 | /// another connector, and after getting an underlying connection, it established 25 | /// a TCP tunnel over it using SOCKSv4. 26 | #[derive(Debug, Clone)] 27 | pub struct SocksV4 { 28 | inner: C, 29 | config: SocksConfig, 30 | } 31 | 32 | #[derive(Debug, Clone)] 33 | struct SocksConfig { 34 | proxy: Uri, 35 | local_dns: bool, 36 | } 37 | 38 | pin_project! { 39 | // Not publicly exported (so missing_docs doesn't trigger). 40 | // 41 | // We return this `Future` instead of the `Pin>` directly 42 | // so that users don't rely on it fitting in a `Pin>` slot 43 | // (and thus we can change the type in the future). 44 | #[must_use = "futures do nothing unless polled"] 45 | #[allow(missing_debug_implementations)] 46 | pub struct Handshaking { 47 | #[pin] 48 | fut: BoxHandshaking, 49 | _marker: std::marker::PhantomData 50 | } 51 | } 52 | 53 | type BoxHandshaking = Pin>> + Send>>; 54 | 55 | impl SocksV4 { 56 | /// Create a new SOCKSv4 handshake service 57 | /// 58 | /// Wraps an underlying connector and stores the address of a tunneling 59 | /// proxying server. 60 | /// 61 | /// A `SocksV4` can then be called with any destination. The `dst` passed to 62 | /// `call` will not be used to create the underlying connection, but will 63 | /// be used in a SOCKS handshake with the proxy destination. 
64 | pub fn new(proxy_dst: Uri, connector: C) -> Self { 65 | Self { 66 | inner: connector, 67 | config: SocksConfig::new(proxy_dst), 68 | } 69 | } 70 | 71 | /// Resolve domain names locally on the client, rather than on the proxy server. 72 | /// 73 | /// Disabled by default as local resolution of domain names can be detected as a 74 | /// DNS leak. 75 | pub fn local_dns(mut self, local_dns: bool) -> Self { 76 | self.config.local_dns = local_dns; 77 | self 78 | } 79 | } 80 | 81 | impl SocksConfig { 82 | pub fn new(proxy: Uri) -> Self { 83 | Self { 84 | proxy, 85 | local_dns: false, 86 | } 87 | } 88 | 89 | async fn execute( 90 | self, 91 | mut conn: T, 92 | host: String, 93 | port: u16, 94 | ) -> Result> 95 | where 96 | T: Read + Write + Unpin, 97 | { 98 | let address = match host.parse::() { 99 | Ok(IpAddr::V6(_)) => return Err(SocksV4Error::IpV6.into()), 100 | Ok(IpAddr::V4(ip)) => Address::Socket(SocketAddrV4::new(ip, port)), 101 | Err(_) => { 102 | if self.local_dns { 103 | (host, port) 104 | .to_socket_addrs()? 105 | .find_map(|s| { 106 | if let SocketAddr::V4(v4) = s { 107 | Some(Address::Socket(v4)) 108 | } else { 109 | None 110 | } 111 | }) 112 | .ok_or(super::SocksError::DnsFailure)? 
113 | } else { 114 | Address::Domain(host, port) 115 | } 116 | } 117 | }; 118 | 119 | let mut send_buf = BytesMut::with_capacity(1024); 120 | let mut recv_buf = BytesMut::with_capacity(1024); 121 | 122 | // Send Request 123 | let req = Request(&address); 124 | let n = req.write_to_buf(&mut send_buf)?; 125 | crate::rt::write_all(&mut conn, &send_buf[..n]).await?; 126 | 127 | // Read Response 128 | let res: Response = super::read_message(&mut conn, &mut recv_buf).await?; 129 | if res.0 == Status::Success { 130 | Ok(conn) 131 | } else { 132 | Err(SocksV4Error::Command(res.0).into()) 133 | } 134 | } 135 | } 136 | 137 | impl Service for SocksV4 138 | where 139 | C: Service, 140 | C::Future: Send + 'static, 141 | C::Response: Read + Write + Unpin + Send + 'static, 142 | C::Error: Send + 'static, 143 | { 144 | type Response = C::Response; 145 | type Error = super::SocksError; 146 | type Future = Handshaking; 147 | 148 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { 149 | self.inner.poll_ready(cx).map_err(super::SocksError::Inner) 150 | } 151 | 152 | fn call(&mut self, dst: Uri) -> Self::Future { 153 | let config = self.config.clone(); 154 | let connecting = self.inner.call(config.proxy.clone()); 155 | 156 | let fut = async move { 157 | let port = dst.port().map(|p| p.as_u16()).unwrap_or(443); 158 | let host = dst 159 | .host() 160 | .ok_or(super::SocksError::MissingHost)? 
161 | .to_string(); 162 | 163 | let conn = connecting.await.map_err(super::SocksError::Inner)?; 164 | config.execute(conn, host, port).await 165 | }; 166 | 167 | Handshaking { 168 | fut: Box::pin(fut), 169 | _marker: Default::default(), 170 | } 171 | } 172 | } 173 | 174 | impl Future for Handshaking 175 | where 176 | F: Future>, 177 | { 178 | type Output = Result>; 179 | 180 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 181 | self.project().fut.poll(cx) 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /src/client/legacy/connect/proxy/socks/v5/errors.rs: -------------------------------------------------------------------------------- 1 | use super::Status; 2 | 3 | #[derive(Debug)] 4 | pub enum SocksV5Error { 5 | HostTooLong, 6 | Auth(AuthError), 7 | Command(Status), 8 | } 9 | 10 | #[derive(Debug)] 11 | pub enum AuthError { 12 | Unsupported, 13 | MethodMismatch, 14 | Failed, 15 | } 16 | 17 | impl From for SocksV5Error { 18 | fn from(err: Status) -> Self { 19 | Self::Command(err) 20 | } 21 | } 22 | 23 | impl From for SocksV5Error { 24 | fn from(err: AuthError) -> Self { 25 | Self::Auth(err) 26 | } 27 | } 28 | 29 | impl std::fmt::Display for SocksV5Error { 30 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 31 | match self { 32 | Self::HostTooLong => f.write_str("host address is more than 255 characters"), 33 | Self::Command(e) => e.fmt(f), 34 | Self::Auth(e) => e.fmt(f), 35 | } 36 | } 37 | } 38 | 39 | impl std::fmt::Display for AuthError { 40 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 41 | f.write_str(match self { 42 | Self::Unsupported => "server does not support user/pass authentication", 43 | Self::MethodMismatch => "server implements authentication incorrectly", 44 | Self::Failed => "credentials not accepted", 45 | }) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- 
/src/client/legacy/connect/proxy/socks/v5/messages.rs: -------------------------------------------------------------------------------- 1 | use super::super::{ParsingError, SerializeError}; 2 | 3 | use bytes::{Buf, BufMut, BytesMut}; 4 | use std::net::SocketAddr; 5 | 6 | /// +----+----------+----------+ 7 | /// |VER | NMETHODS | METHODS | 8 | /// +----+----------+----------+ 9 | /// | 1 | 1 | 1 to 255 | 10 | /// +----+----------+----------+ 11 | #[derive(Debug)] 12 | pub struct NegotiationReq<'a>(pub &'a AuthMethod); 13 | 14 | /// +----+--------+ 15 | /// |VER | METHOD | 16 | /// +----+--------+ 17 | /// | 1 | 1 | 18 | /// +----+--------+ 19 | #[derive(Debug)] 20 | pub struct NegotiationRes(pub AuthMethod); 21 | 22 | /// +----+------+----------+------+----------+ 23 | /// |VER | ULEN | UNAME | PLEN | PASSWD | 24 | /// +----+------+----------+------+----------+ 25 | /// | 1 | 1 | 1 to 255 | 1 | 1 to 255 | 26 | /// +----+------+----------+------+----------+ 27 | #[derive(Debug)] 28 | pub struct AuthenticationReq<'a>(pub &'a str, pub &'a str); 29 | 30 | /// +----+--------+ 31 | /// |VER | STATUS | 32 | /// +----+--------+ 33 | /// | 1 | 1 | 34 | /// +----+--------+ 35 | #[derive(Debug)] 36 | pub struct AuthenticationRes(pub bool); 37 | 38 | /// +----+-----+-------+------+----------+----------+ 39 | /// |VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT | 40 | /// +----+-----+-------+------+----------+----------+ 41 | /// | 1 | 1 | X'00' | 1 | Variable | 2 | 42 | /// +----+-----+-------+------+----------+----------+ 43 | #[derive(Debug)] 44 | pub struct ProxyReq<'a>(pub &'a Address); 45 | 46 | /// +----+-----+-------+------+----------+----------+ 47 | /// |VER | REP | RSV | ATYP | BND.ADDR | BND.PORT | 48 | /// +----+-----+-------+------+----------+----------+ 49 | /// | 1 | 1 | X'00' | 1 | Variable | 2 | 50 | /// +----+-----+-------+------+----------+----------+ 51 | #[derive(Debug)] 52 | pub struct ProxyRes(pub Status); 53 | 54 | #[repr(u8)] 55 | #[derive(Debug, Copy, 
Clone, PartialEq)] 56 | pub enum AuthMethod { 57 | NoAuth = 0x00, 58 | UserPass = 0x02, 59 | NoneAcceptable = 0xFF, 60 | } 61 | 62 | #[derive(Debug)] 63 | pub enum Address { 64 | Socket(SocketAddr), 65 | Domain(String, u16), 66 | } 67 | 68 | #[derive(Debug, Copy, Clone, PartialEq)] 69 | pub enum Status { 70 | Success, 71 | GeneralServerFailure, 72 | ConnectionNotAllowed, 73 | NetworkUnreachable, 74 | HostUnreachable, 75 | ConnectionRefused, 76 | TtlExpired, 77 | CommandNotSupported, 78 | AddressTypeNotSupported, 79 | } 80 | 81 | impl NegotiationReq<'_> { 82 | pub fn write_to_buf(&self, buf: &mut BytesMut) -> Result { 83 | if buf.capacity() - buf.len() < 3 { 84 | return Err(SerializeError::WouldOverflow); 85 | } 86 | 87 | buf.put_u8(0x05); // Version 88 | buf.put_u8(0x01); // Number of authentication methods 89 | buf.put_u8(*self.0 as u8); // Authentication method 90 | 91 | Ok(3) 92 | } 93 | } 94 | 95 | impl TryFrom<&mut BytesMut> for NegotiationRes { 96 | type Error = ParsingError; 97 | 98 | fn try_from(buf: &mut BytesMut) -> Result { 99 | if buf.remaining() < 2 { 100 | return Err(ParsingError::Incomplete); 101 | } 102 | 103 | if buf.get_u8() != 0x05 { 104 | return Err(ParsingError::Other); 105 | } 106 | 107 | let method = buf.get_u8().try_into()?; 108 | Ok(Self(method)) 109 | } 110 | } 111 | 112 | impl AuthenticationReq<'_> { 113 | pub fn write_to_buf(&self, buf: &mut BytesMut) -> Result { 114 | if buf.capacity() - buf.len() < 3 + self.0.len() + self.1.len() { 115 | return Err(SerializeError::WouldOverflow); 116 | } 117 | 118 | buf.put_u8(0x01); // Version 119 | 120 | buf.put_u8(self.0.len() as u8); // Username length (guarenteed to be 255 or less) 121 | buf.put_slice(self.0.as_bytes()); // Username 122 | 123 | buf.put_u8(self.1.len() as u8); // Password length (guarenteed to be 255 or less) 124 | buf.put_slice(self.1.as_bytes()); // Password 125 | 126 | Ok(3 + self.0.len() + self.1.len()) 127 | } 128 | } 129 | 130 | impl TryFrom<&mut BytesMut> for 
AuthenticationRes { 131 | type Error = ParsingError; 132 | 133 | fn try_from(buf: &mut BytesMut) -> Result { 134 | if buf.remaining() < 2 { 135 | return Err(ParsingError::Incomplete); 136 | } 137 | 138 | if buf.get_u8() != 0x01 { 139 | return Err(ParsingError::Other); 140 | } 141 | 142 | if buf.get_u8() == 0 { 143 | Ok(Self(true)) 144 | } else { 145 | Ok(Self(false)) 146 | } 147 | } 148 | } 149 | 150 | impl ProxyReq<'_> { 151 | pub fn write_to_buf(&self, buf: &mut BytesMut) -> Result { 152 | let addr_len = match self.0 { 153 | Address::Socket(SocketAddr::V4(_)) => 1 + 4 + 2, 154 | Address::Socket(SocketAddr::V6(_)) => 1 + 16 + 2, 155 | Address::Domain(ref domain, _) => 1 + 1 + domain.len() + 2, 156 | }; 157 | 158 | if buf.capacity() - buf.len() < 3 + addr_len { 159 | return Err(SerializeError::WouldOverflow); 160 | } 161 | 162 | buf.put_u8(0x05); // Version 163 | buf.put_u8(0x01); // TCP tunneling command 164 | buf.put_u8(0x00); // Reserved 165 | let _ = self.0.write_to_buf(buf); // Address 166 | 167 | Ok(3 + addr_len) 168 | } 169 | } 170 | 171 | impl TryFrom<&mut BytesMut> for ProxyRes { 172 | type Error = ParsingError; 173 | 174 | fn try_from(buf: &mut BytesMut) -> Result { 175 | if buf.remaining() < 2 { 176 | return Err(ParsingError::Incomplete); 177 | } 178 | 179 | // VER 180 | if buf.get_u8() != 0x05 { 181 | return Err(ParsingError::Other); 182 | } 183 | 184 | // REP 185 | let status = buf.get_u8().try_into()?; 186 | 187 | // RSV 188 | if buf.get_u8() != 0x00 { 189 | return Err(ParsingError::Other); 190 | } 191 | 192 | // ATYP + ADDR 193 | Address::try_from(buf)?; 194 | 195 | Ok(Self(status)) 196 | } 197 | } 198 | 199 | impl Address { 200 | pub fn write_to_buf(&self, buf: &mut BytesMut) -> Result { 201 | match self { 202 | Self::Socket(SocketAddr::V4(v4)) => { 203 | if buf.capacity() - buf.len() < 1 + 4 + 2 { 204 | return Err(SerializeError::WouldOverflow); 205 | } 206 | 207 | buf.put_u8(0x01); 208 | buf.put_slice(&v4.ip().octets()); 209 | 
buf.put_u16(v4.port()); // Network Order/BigEndian for port 210 | 211 | Ok(7) 212 | } 213 | 214 | Self::Socket(SocketAddr::V6(v6)) => { 215 | if buf.capacity() - buf.len() < 1 + 16 + 2 { 216 | return Err(SerializeError::WouldOverflow); 217 | } 218 | 219 | buf.put_u8(0x04); 220 | buf.put_slice(&v6.ip().octets()); 221 | buf.put_u16(v6.port()); // Network Order/BigEndian for port 222 | 223 | Ok(19) 224 | } 225 | 226 | Self::Domain(domain, port) => { 227 | if buf.capacity() - buf.len() < 1 + 1 + domain.len() + 2 { 228 | return Err(SerializeError::WouldOverflow); 229 | } 230 | 231 | buf.put_u8(0x03); 232 | buf.put_u8(domain.len() as u8); // Guarenteed to be less than 255 233 | buf.put_slice(domain.as_bytes()); 234 | buf.put_u16(*port); 235 | 236 | Ok(4 + domain.len()) 237 | } 238 | } 239 | } 240 | } 241 | 242 | impl TryFrom<&mut BytesMut> for Address { 243 | type Error = ParsingError; 244 | 245 | fn try_from(buf: &mut BytesMut) -> Result { 246 | if buf.remaining() < 2 { 247 | return Err(ParsingError::Incomplete); 248 | } 249 | 250 | Ok(match buf.get_u8() { 251 | 0x01 => { 252 | let mut ip = [0; 4]; 253 | 254 | if buf.remaining() < 6 { 255 | return Err(ParsingError::Incomplete); 256 | } 257 | 258 | buf.copy_to_slice(&mut ip); 259 | let port = buf.get_u16(); 260 | 261 | Self::Socket(SocketAddr::new(ip.into(), port)) 262 | } 263 | 264 | 0x03 => { 265 | let len = buf.get_u8(); 266 | 267 | if len == 0 { 268 | return Err(ParsingError::Other); 269 | } else if buf.remaining() < (len as usize) + 2 { 270 | return Err(ParsingError::Incomplete); 271 | } 272 | 273 | let domain = std::str::from_utf8(&buf[..len as usize]) 274 | .map_err(|_| ParsingError::Other)? 
275 | .to_string(); 276 | 277 | let port = buf.get_u16(); 278 | 279 | Self::Domain(domain, port) 280 | } 281 | 282 | 0x04 => { 283 | let mut ip = [0; 16]; 284 | 285 | if buf.remaining() < 6 { 286 | return Err(ParsingError::Incomplete); 287 | } 288 | buf.copy_to_slice(&mut ip); 289 | let port = buf.get_u16(); 290 | 291 | Self::Socket(SocketAddr::new(ip.into(), port)) 292 | } 293 | 294 | _ => return Err(ParsingError::Other), 295 | }) 296 | } 297 | } 298 | 299 | impl TryFrom for Status { 300 | type Error = ParsingError; 301 | 302 | fn try_from(byte: u8) -> Result { 303 | Ok(match byte { 304 | 0x00 => Self::Success, 305 | 306 | 0x01 => Self::GeneralServerFailure, 307 | 0x02 => Self::ConnectionNotAllowed, 308 | 0x03 => Self::NetworkUnreachable, 309 | 0x04 => Self::HostUnreachable, 310 | 0x05 => Self::ConnectionRefused, 311 | 0x06 => Self::TtlExpired, 312 | 0x07 => Self::CommandNotSupported, 313 | 0x08 => Self::AddressTypeNotSupported, 314 | _ => return Err(ParsingError::Other), 315 | }) 316 | } 317 | } 318 | 319 | impl TryFrom for AuthMethod { 320 | type Error = ParsingError; 321 | 322 | fn try_from(byte: u8) -> Result { 323 | Ok(match byte { 324 | 0x00 => Self::NoAuth, 325 | 0x02 => Self::UserPass, 326 | 0xFF => Self::NoneAcceptable, 327 | 328 | _ => return Err(ParsingError::Other), 329 | }) 330 | } 331 | } 332 | 333 | impl std::fmt::Display for Status { 334 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 335 | f.write_str(match self { 336 | Self::Success => "success", 337 | Self::GeneralServerFailure => "general server failure", 338 | Self::ConnectionNotAllowed => "connection not allowed", 339 | Self::NetworkUnreachable => "network unreachable", 340 | Self::HostUnreachable => "host unreachable", 341 | Self::ConnectionRefused => "connection refused", 342 | Self::TtlExpired => "ttl expired", 343 | Self::CommandNotSupported => "command not supported", 344 | Self::AddressTypeNotSupported => "address type not supported", 345 | }) 346 | } 347 | } 348 
| -------------------------------------------------------------------------------- /src/client/legacy/connect/proxy/socks/v5/mod.rs: -------------------------------------------------------------------------------- 1 | mod errors; 2 | pub use errors::*; 3 | 4 | mod messages; 5 | use messages::*; 6 | 7 | use std::future::Future; 8 | use std::pin::Pin; 9 | use std::task::{Context, Poll}; 10 | 11 | use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; 12 | 13 | use http::Uri; 14 | use hyper::rt::{Read, Write}; 15 | use tower_service::Service; 16 | 17 | use bytes::BytesMut; 18 | 19 | use pin_project_lite::pin_project; 20 | 21 | /// Tunnel Proxy via SOCKSv5 22 | /// 23 | /// This is a connector that can be used by the `legacy::Client`. It wraps 24 | /// another connector, and after getting an underlying connection, it established 25 | /// a TCP tunnel over it using SOCKSv5. 26 | #[derive(Debug, Clone)] 27 | pub struct SocksV5 { 28 | inner: C, 29 | config: SocksConfig, 30 | } 31 | 32 | #[derive(Debug, Clone)] 33 | pub struct SocksConfig { 34 | proxy: Uri, 35 | proxy_auth: Option<(String, String)>, 36 | 37 | local_dns: bool, 38 | optimistic: bool, 39 | } 40 | 41 | #[derive(Debug)] 42 | enum State { 43 | SendingNegReq, 44 | ReadingNegRes, 45 | SendingAuthReq, 46 | ReadingAuthRes, 47 | SendingProxyReq, 48 | ReadingProxyRes, 49 | } 50 | 51 | pin_project! { 52 | // Not publicly exported (so missing_docs doesn't trigger). 53 | // 54 | // We return this `Future` instead of the `Pin>` directly 55 | // so that users don't rely on it fitting in a `Pin>` slot 56 | // (and thus we can change the type in the future). 57 | #[must_use = "futures do nothing unless polled"] 58 | #[allow(missing_debug_implementations)] 59 | pub struct Handshaking { 60 | #[pin] 61 | fut: BoxHandshaking, 62 | _marker: std::marker::PhantomData 63 | } 64 | } 65 | 66 | type BoxHandshaking = Pin>> + Send>>; 67 | 68 | impl SocksV5 { 69 | /// Create a new SOCKSv5 handshake service. 
70 | /// 71 | /// Wraps an underlying connector and stores the address of a tunneling 72 | /// proxying server. 73 | /// 74 | /// A `SocksV5` can then be called with any destination. The `dst` passed to 75 | /// `call` will not be used to create the underlying connection, but will 76 | /// be used in a SOCKS handshake with the proxy destination. 77 | pub fn new(proxy_dst: Uri, connector: C) -> Self { 78 | Self { 79 | inner: connector, 80 | config: SocksConfig::new(proxy_dst), 81 | } 82 | } 83 | 84 | /// Use User/Pass authentication method during handshake. 85 | /// 86 | /// Username and Password must be maximum of 255 characters each. 87 | /// 0 length strings are allowed despite RFC prohibiting it. This is done so that 88 | /// for compatablity with server implementations that require it for IP authentication. 89 | pub fn with_auth(mut self, user: String, pass: String) -> Self { 90 | self.config.proxy_auth = Some((user, pass)); 91 | self 92 | } 93 | 94 | /// Resolve domain names locally on the client, rather than on the proxy server. 95 | /// 96 | /// Disabled by default as local resolution of domain names can be detected as a 97 | /// DNS leak. 98 | pub fn local_dns(mut self, local_dns: bool) -> Self { 99 | self.config.local_dns = local_dns; 100 | self 101 | } 102 | 103 | /// Send all messages of the handshake optmistically (without waiting for server response). 104 | /// 105 | /// Typical SOCKS handshake with auithentication takes 3 round trips. Optimistic sending 106 | /// can reduce round trip times and dramatically increase speed of handshake at the cost of 107 | /// reduced portability; many server implementations do not support optimistic sending as it 108 | /// is not defined in the RFC (RFC 1928). 109 | /// 110 | /// Recommended to ensure connector works correctly without optimistic sending before trying 111 | /// with optimistic sending. 
112 | pub fn send_optimistically(mut self, optimistic: bool) -> Self { 113 | self.config.optimistic = optimistic; 114 | self 115 | } 116 | } 117 | 118 | impl SocksConfig { 119 | fn new(proxy: Uri) -> Self { 120 | Self { 121 | proxy, 122 | proxy_auth: None, 123 | 124 | local_dns: false, 125 | optimistic: false, 126 | } 127 | } 128 | 129 | async fn execute( 130 | self, 131 | mut conn: T, 132 | host: String, 133 | port: u16, 134 | ) -> Result> 135 | where 136 | T: Read + Write + Unpin, 137 | { 138 | let address = match host.parse::() { 139 | Ok(ip) => Address::Socket(SocketAddr::new(ip, port)), 140 | Err(_) if host.len() <= 255 => { 141 | if self.local_dns { 142 | let socket = (host, port) 143 | .to_socket_addrs()? 144 | .next() 145 | .ok_or(super::SocksError::DnsFailure)?; 146 | 147 | Address::Socket(socket) 148 | } else { 149 | Address::Domain(host, port) 150 | } 151 | } 152 | Err(_) => return Err(SocksV5Error::HostTooLong.into()), 153 | }; 154 | 155 | let method = if self.proxy_auth.is_some() { 156 | AuthMethod::UserPass 157 | } else { 158 | AuthMethod::NoAuth 159 | }; 160 | 161 | let mut recv_buf = BytesMut::with_capacity(513); // Max length of valid recievable message is 513 from Auth Request 162 | let mut send_buf = BytesMut::with_capacity(262); // Max length of valid sendable message is 262 from Auth Response 163 | let mut state = State::SendingNegReq; 164 | 165 | loop { 166 | match state { 167 | State::SendingNegReq => { 168 | let req = NegotiationReq(&method); 169 | 170 | let start = send_buf.len(); 171 | req.write_to_buf(&mut send_buf)?; 172 | crate::rt::write_all(&mut conn, &send_buf[start..]).await?; 173 | 174 | if self.optimistic { 175 | if method == AuthMethod::UserPass { 176 | state = State::SendingAuthReq; 177 | } else { 178 | state = State::SendingProxyReq; 179 | } 180 | } else { 181 | state = State::ReadingNegRes; 182 | } 183 | } 184 | 185 | State::ReadingNegRes => { 186 | let res: NegotiationRes = super::read_message(&mut conn, &mut 
recv_buf).await?; 187 | 188 | if res.0 == AuthMethod::NoneAcceptable { 189 | return Err(SocksV5Error::Auth(AuthError::Unsupported).into()); 190 | } 191 | 192 | if res.0 != method { 193 | return Err(SocksV5Error::Auth(AuthError::MethodMismatch).into()); 194 | } 195 | 196 | if self.optimistic { 197 | if res.0 == AuthMethod::UserPass { 198 | state = State::ReadingAuthRes; 199 | } else { 200 | state = State::ReadingProxyRes; 201 | } 202 | } else if res.0 == AuthMethod::UserPass { 203 | state = State::SendingAuthReq; 204 | } else { 205 | state = State::SendingProxyReq; 206 | } 207 | } 208 | 209 | State::SendingAuthReq => { 210 | let (user, pass) = self.proxy_auth.as_ref().unwrap(); 211 | let req = AuthenticationReq(user, pass); 212 | 213 | let start = send_buf.len(); 214 | req.write_to_buf(&mut send_buf)?; 215 | crate::rt::write_all(&mut conn, &send_buf[start..]).await?; 216 | 217 | if self.optimistic { 218 | state = State::SendingProxyReq; 219 | } else { 220 | state = State::ReadingAuthRes; 221 | } 222 | } 223 | 224 | State::ReadingAuthRes => { 225 | let res: AuthenticationRes = 226 | super::read_message(&mut conn, &mut recv_buf).await?; 227 | 228 | if !res.0 { 229 | return Err(SocksV5Error::Auth(AuthError::Failed).into()); 230 | } 231 | 232 | if self.optimistic { 233 | state = State::ReadingProxyRes; 234 | } else { 235 | state = State::SendingProxyReq; 236 | } 237 | } 238 | 239 | State::SendingProxyReq => { 240 | let req = ProxyReq(&address); 241 | 242 | let start = send_buf.len(); 243 | req.write_to_buf(&mut send_buf)?; 244 | crate::rt::write_all(&mut conn, &send_buf[start..]).await?; 245 | 246 | if self.optimistic { 247 | state = State::ReadingNegRes; 248 | } else { 249 | state = State::ReadingProxyRes; 250 | } 251 | } 252 | 253 | State::ReadingProxyRes => { 254 | let res: ProxyRes = super::read_message(&mut conn, &mut recv_buf).await?; 255 | 256 | if res.0 == Status::Success { 257 | return Ok(conn); 258 | } else { 259 | return 
Err(SocksV5Error::Command(res.0).into()); 260 | } 261 | } 262 | } 263 | } 264 | } 265 | } 266 | 267 | impl Service for SocksV5 268 | where 269 | C: Service, 270 | C::Future: Send + 'static, 271 | C::Response: Read + Write + Unpin + Send + 'static, 272 | C::Error: Send + 'static, 273 | { 274 | type Response = C::Response; 275 | type Error = super::SocksError; 276 | type Future = Handshaking; 277 | 278 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { 279 | self.inner.poll_ready(cx).map_err(super::SocksError::Inner) 280 | } 281 | 282 | fn call(&mut self, dst: Uri) -> Self::Future { 283 | let config = self.config.clone(); 284 | let connecting = self.inner.call(config.proxy.clone()); 285 | 286 | let fut = async move { 287 | let port = dst.port().map(|p| p.as_u16()).unwrap_or(443); 288 | let host = dst 289 | .host() 290 | .ok_or(super::SocksError::MissingHost)? 291 | .to_string(); 292 | 293 | let conn = connecting.await.map_err(super::SocksError::Inner)?; 294 | config.execute(conn, host, port).await 295 | }; 296 | 297 | Handshaking { 298 | fut: Box::pin(fut), 299 | _marker: Default::default(), 300 | } 301 | } 302 | } 303 | 304 | impl Future for Handshaking 305 | where 306 | F: Future>, 307 | { 308 | type Output = Result>; 309 | 310 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 311 | self.project().fut.poll(cx) 312 | } 313 | } 314 | -------------------------------------------------------------------------------- /src/client/legacy/connect/proxy/tunnel.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error as StdError; 2 | use std::future::Future; 3 | use std::marker::{PhantomData, Unpin}; 4 | use std::pin::Pin; 5 | use std::task::{self, Poll}; 6 | 7 | use futures_core::ready; 8 | use http::{HeaderMap, HeaderValue, Uri}; 9 | use hyper::rt::{Read, Write}; 10 | use pin_project_lite::pin_project; 11 | use tower_service::Service; 12 | 13 | /// Tunnel Proxy via HTTP CONNECT 14 | /// 15 | /// 
This is a connector that can be used by the `legacy::Client`. It wraps 16 | /// another connector, and after getting an underlying connection, it creates 17 | /// an HTTP CONNECT tunnel over it. 18 | #[derive(Debug)] 19 | pub struct Tunnel { 20 | headers: Headers, 21 | inner: C, 22 | proxy_dst: Uri, 23 | } 24 | 25 | #[derive(Clone, Debug)] 26 | enum Headers { 27 | Empty, 28 | Auth(HeaderValue), 29 | Extra(HeaderMap), 30 | } 31 | 32 | #[derive(Debug)] 33 | pub enum TunnelError { 34 | ConnectFailed(Box), 35 | Io(std::io::Error), 36 | MissingHost, 37 | ProxyAuthRequired, 38 | ProxyHeadersTooLong, 39 | TunnelUnexpectedEof, 40 | TunnelUnsuccessful, 41 | } 42 | 43 | pin_project! { 44 | // Not publicly exported (so missing_docs doesn't trigger). 45 | // 46 | // We return this `Future` instead of the `Pin>` directly 47 | // so that users don't rely on it fitting in a `Pin>` slot 48 | // (and thus we can change the type in the future). 49 | #[must_use = "futures do nothing unless polled"] 50 | #[allow(missing_debug_implementations)] 51 | pub struct Tunneling { 52 | #[pin] 53 | fut: BoxTunneling, 54 | _marker: PhantomData, 55 | } 56 | } 57 | 58 | type BoxTunneling = Pin> + Send>>; 59 | 60 | impl Tunnel { 61 | /// Create a new Tunnel service. 62 | /// 63 | /// This wraps an underlying connector, and stores the address of a 64 | /// tunneling proxy server. 65 | /// 66 | /// A `Tunnel` can then be called with any destination. The `dst` passed to 67 | /// `call` will not be used to create the underlying connection, but will 68 | /// be used in an HTTP CONNECT request sent to the proxy destination. 69 | pub fn new(proxy_dst: Uri, connector: C) -> Self { 70 | Self { 71 | headers: Headers::Empty, 72 | inner: connector, 73 | proxy_dst, 74 | } 75 | } 76 | 77 | /// Add `proxy-authorization` header value to the CONNECT request. 
78 | pub fn with_auth(mut self, mut auth: HeaderValue) -> Self { 79 | // just in case the user forgot 80 | auth.set_sensitive(true); 81 | match self.headers { 82 | Headers::Empty => { 83 | self.headers = Headers::Auth(auth); 84 | } 85 | Headers::Auth(ref mut existing) => { 86 | *existing = auth; 87 | } 88 | Headers::Extra(ref mut extra) => { 89 | extra.insert(http::header::PROXY_AUTHORIZATION, auth); 90 | } 91 | } 92 | 93 | self 94 | } 95 | 96 | /// Add extra headers to be sent with the CONNECT request. 97 | /// 98 | /// If existing headers have been set, these will be merged. 99 | pub fn with_headers(mut self, mut headers: HeaderMap) -> Self { 100 | match self.headers { 101 | Headers::Empty => { 102 | self.headers = Headers::Extra(headers); 103 | } 104 | Headers::Auth(auth) => { 105 | headers 106 | .entry(http::header::PROXY_AUTHORIZATION) 107 | .or_insert(auth); 108 | self.headers = Headers::Extra(headers); 109 | } 110 | Headers::Extra(ref mut extra) => { 111 | extra.extend(headers); 112 | } 113 | } 114 | 115 | self 116 | } 117 | } 118 | 119 | impl Service for Tunnel 120 | where 121 | C: Service, 122 | C::Future: Send + 'static, 123 | C::Response: Read + Write + Unpin + Send + 'static, 124 | C::Error: Into>, 125 | { 126 | type Response = C::Response; 127 | type Error = TunnelError; 128 | type Future = Tunneling; 129 | 130 | fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { 131 | ready!(self.inner.poll_ready(cx)).map_err(|e| TunnelError::ConnectFailed(e.into()))?; 132 | Poll::Ready(Ok(())) 133 | } 134 | 135 | fn call(&mut self, dst: Uri) -> Self::Future { 136 | let connecting = self.inner.call(self.proxy_dst.clone()); 137 | let headers = self.headers.clone(); 138 | 139 | Tunneling { 140 | fut: Box::pin(async move { 141 | let conn = connecting 142 | .await 143 | .map_err(|e| TunnelError::ConnectFailed(e.into()))?; 144 | tunnel( 145 | conn, 146 | dst.host().ok_or(TunnelError::MissingHost)?, 147 | dst.port().map(|p| p.as_u16()).unwrap_or(443), 148 | 
&headers, 149 | ) 150 | .await 151 | }), 152 | _marker: PhantomData, 153 | } 154 | } 155 | } 156 | 157 | impl Future for Tunneling 158 | where 159 | F: Future>, 160 | { 161 | type Output = Result; 162 | 163 | fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { 164 | self.project().fut.poll(cx) 165 | } 166 | } 167 | 168 | async fn tunnel(mut conn: T, host: &str, port: u16, headers: &Headers) -> Result 169 | where 170 | T: Read + Write + Unpin, 171 | { 172 | let mut buf = format!( 173 | "\ 174 | CONNECT {host}:{port} HTTP/1.1\r\n\ 175 | Host: {host}:{port}\r\n\ 176 | " 177 | ) 178 | .into_bytes(); 179 | 180 | match headers { 181 | Headers::Auth(auth) => { 182 | buf.extend_from_slice(b"Proxy-Authorization: "); 183 | buf.extend_from_slice(auth.as_bytes()); 184 | buf.extend_from_slice(b"\r\n"); 185 | } 186 | Headers::Extra(extra) => { 187 | for (name, value) in extra { 188 | buf.extend_from_slice(name.as_str().as_bytes()); 189 | buf.extend_from_slice(b": "); 190 | buf.extend_from_slice(value.as_bytes()); 191 | buf.extend_from_slice(b"\r\n"); 192 | } 193 | } 194 | Headers::Empty => (), 195 | } 196 | 197 | // headers end 198 | buf.extend_from_slice(b"\r\n"); 199 | 200 | crate::rt::write_all(&mut conn, &buf) 201 | .await 202 | .map_err(TunnelError::Io)?; 203 | 204 | let mut buf = [0; 8192]; 205 | let mut pos = 0; 206 | 207 | loop { 208 | let n = crate::rt::read(&mut conn, &mut buf[pos..]) 209 | .await 210 | .map_err(TunnelError::Io)?; 211 | 212 | if n == 0 { 213 | return Err(TunnelError::TunnelUnexpectedEof); 214 | } 215 | pos += n; 216 | 217 | let recvd = &buf[..pos]; 218 | if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") { 219 | if recvd.ends_with(b"\r\n\r\n") { 220 | return Ok(conn); 221 | } 222 | if pos == buf.len() { 223 | return Err(TunnelError::ProxyHeadersTooLong); 224 | } 225 | // else read more 226 | } else if recvd.starts_with(b"HTTP/1.1 407") { 227 | return Err(TunnelError::ProxyAuthRequired); 228 | } else { 229 | 
return Err(TunnelError::TunnelUnsuccessful); 230 | } 231 | } 232 | } 233 | 234 | impl std::fmt::Display for TunnelError { 235 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 236 | f.write_str("tunnel error: ")?; 237 | 238 | f.write_str(match self { 239 | TunnelError::MissingHost => "missing destination host", 240 | TunnelError::ProxyAuthRequired => "proxy authorization required", 241 | TunnelError::ProxyHeadersTooLong => "proxy response headers too long", 242 | TunnelError::TunnelUnexpectedEof => "unexpected end of file", 243 | TunnelError::TunnelUnsuccessful => "unsuccessful", 244 | TunnelError::ConnectFailed(_) => "failed to create underlying connection", 245 | TunnelError::Io(_) => "io error establishing tunnel", 246 | }) 247 | } 248 | } 249 | 250 | impl std::error::Error for TunnelError { 251 | fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 252 | match self { 253 | TunnelError::Io(ref e) => Some(e), 254 | TunnelError::ConnectFailed(ref e) => Some(&**e), 255 | _ => None, 256 | } 257 | } 258 | } 259 | -------------------------------------------------------------------------------- /src/client/legacy/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(feature = "http1", feature = "http2"))] 2 | mod client; 3 | #[cfg(any(feature = "http1", feature = "http2"))] 4 | pub use client::{Builder, Client, Error, ResponseFuture}; 5 | 6 | pub mod connect; 7 | #[doc(hidden)] 8 | // Publicly available, but just for legacy purposes. A better pool will be 9 | // designed. 10 | pub mod pool; 11 | -------------------------------------------------------------------------------- /src/client/mod.rs: -------------------------------------------------------------------------------- 1 | //! 
HTTP client utilities 2 | 3 | /// Legacy implementations of `connect` module and `Client` 4 | #[cfg(feature = "client-legacy")] 5 | pub mod legacy; 6 | 7 | #[cfg(feature = "client-proxy")] 8 | pub mod proxy; 9 | -------------------------------------------------------------------------------- /src/client/proxy/matcher.rs: -------------------------------------------------------------------------------- 1 | //! Proxy matchers 2 | //! 3 | //! This module contains different matchers to configure rules for when a proxy 4 | //! should be used, and if so, with what arguments. 5 | //! 6 | //! A [`Matcher`] can be constructed either using environment variables, or 7 | //! a [`Matcher::builder()`]. 8 | //! 9 | //! Once constructed, the `Matcher` can be asked if it intercepts a `Uri` by 10 | //! calling [`Matcher::intercept()`]. 11 | //! 12 | //! An [`Intercept`] includes the destination for the proxy, and any parsed 13 | //! authentication to be used. 14 | 15 | use std::fmt; 16 | use std::net::IpAddr; 17 | 18 | use http::header::HeaderValue; 19 | use ipnet::IpNet; 20 | use percent_encoding::percent_decode_str; 21 | 22 | #[cfg(docsrs)] 23 | pub use self::builder::IntoValue; 24 | #[cfg(not(docsrs))] 25 | use self::builder::IntoValue; 26 | 27 | /// A proxy matcher, usually built from environment variables. 28 | pub struct Matcher { 29 | http: Option, 30 | https: Option, 31 | no: NoProxy, 32 | } 33 | 34 | /// A matched proxy, 35 | /// 36 | /// This is returned by a matcher if a proxy should be used. 37 | #[derive(Clone)] 38 | pub struct Intercept { 39 | uri: http::Uri, 40 | auth: Auth, 41 | } 42 | 43 | /// A builder to create a [`Matcher`]. 44 | /// 45 | /// Construct with [`Matcher::builder()`]. 
46 | #[derive(Default)] 47 | pub struct Builder { 48 | is_cgi: bool, 49 | all: String, 50 | http: String, 51 | https: String, 52 | no: String, 53 | } 54 | 55 | #[derive(Clone)] 56 | enum Auth { 57 | Empty, 58 | Basic(http::header::HeaderValue), 59 | Raw(String, String), 60 | } 61 | 62 | /// A filter for proxy matchers. 63 | /// 64 | /// This type is based off the `NO_PROXY` rules used by curl. 65 | #[derive(Clone, Debug, Default)] 66 | struct NoProxy { 67 | ips: IpMatcher, 68 | domains: DomainMatcher, 69 | } 70 | 71 | #[derive(Clone, Debug, Default)] 72 | struct DomainMatcher(Vec); 73 | 74 | #[derive(Clone, Debug, Default)] 75 | struct IpMatcher(Vec); 76 | 77 | #[derive(Clone, Debug)] 78 | enum Ip { 79 | Address(IpAddr), 80 | Network(IpNet), 81 | } 82 | 83 | // ===== impl Matcher ===== 84 | 85 | impl Matcher { 86 | /// Create a matcher reading the current environment variables. 87 | /// 88 | /// This checks for values in the following variables, treating them the 89 | /// same as curl does: 90 | /// 91 | /// - `ALL_PROXY`/`all_proxy` 92 | /// - `HTTPS_PROXY`/`https_proxy` 93 | /// - `HTTP_PROXY`/`http_proxy` 94 | /// - `NO_PROXY`/`no_proxy` 95 | pub fn from_env() -> Self { 96 | Builder::from_env().build() 97 | } 98 | 99 | /// Create a matcher from the environment or system. 100 | /// 101 | /// This checks the same environment variables as `from_env()`, and if not 102 | /// set, checks the system configuration for values for the OS. 103 | /// 104 | /// This constructor is always available, but if the `client-proxy-system` 105 | /// feature is enabled, it will check more configuration. Use this 106 | /// constructor if you want to allow users to optionally enable more, or 107 | /// use `from_env` if you do not want the values to change based on an 108 | /// enabled feature. 109 | pub fn from_system() -> Self { 110 | Builder::from_system().build() 111 | } 112 | 113 | /// Start a builder to configure a matcher. 
114 | pub fn builder() -> Builder { 115 | Builder::default() 116 | } 117 | 118 | /// Check if the destination should be intercepted by a proxy. 119 | /// 120 | /// If the proxy rules match the destination, a new `Uri` will be returned 121 | /// to connect to. 122 | pub fn intercept(&self, dst: &http::Uri) -> Option { 123 | // TODO(perf): don't need to check `no` if below doesn't match... 124 | if self.no.contains(dst.host()?) { 125 | return None; 126 | } 127 | 128 | match dst.scheme_str() { 129 | Some("http") => self.http.clone(), 130 | Some("https") => self.https.clone(), 131 | _ => None, 132 | } 133 | } 134 | } 135 | 136 | impl fmt::Debug for Matcher { 137 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 138 | let mut b = f.debug_struct("Matcher"); 139 | 140 | if let Some(ref http) = self.http { 141 | b.field("http", http); 142 | } 143 | 144 | if let Some(ref https) = self.https { 145 | b.field("https", https); 146 | } 147 | 148 | if !self.no.is_empty() { 149 | b.field("no", &self.no); 150 | } 151 | b.finish() 152 | } 153 | } 154 | 155 | // ===== impl Intercept ===== 156 | 157 | impl Intercept { 158 | /// Get the `http::Uri` for the target proxy. 159 | pub fn uri(&self) -> &http::Uri { 160 | &self.uri 161 | } 162 | 163 | /// Get any configured basic authorization. 164 | /// 165 | /// This should usually be used with a `Proxy-Authorization` header, to 166 | /// send in Basic format. 
167 | /// 168 | /// # Example 169 | /// 170 | /// ```rust 171 | /// # use hyper_util::client::proxy::matcher::Matcher; 172 | /// # let uri = http::Uri::from_static("https://hyper.rs"); 173 | /// let m = Matcher::builder() 174 | /// .all("https://Aladdin:opensesame@localhost:8887") 175 | /// .build(); 176 | /// 177 | /// let proxy = m.intercept(&uri).expect("example"); 178 | /// let auth = proxy.basic_auth().expect("example"); 179 | /// assert_eq!(auth, "Basic QWxhZGRpbjpvcGVuc2VzYW1l"); 180 | /// ``` 181 | pub fn basic_auth(&self) -> Option<&HeaderValue> { 182 | if let Auth::Basic(ref val) = self.auth { 183 | Some(val) 184 | } else { 185 | None 186 | } 187 | } 188 | 189 | /// Get any configured raw authorization. 190 | /// 191 | /// If not detected as another scheme, this is the username and password 192 | /// that should be sent with whatever protocol the proxy handshake uses. 193 | /// 194 | /// # Example 195 | /// 196 | /// ```rust 197 | /// # use hyper_util::client::proxy::matcher::Matcher; 198 | /// # let uri = http::Uri::from_static("https://hyper.rs"); 199 | /// let m = Matcher::builder() 200 | /// .all("socks5h://Aladdin:opensesame@localhost:8887") 201 | /// .build(); 202 | /// 203 | /// let proxy = m.intercept(&uri).expect("example"); 204 | /// let auth = proxy.raw_auth().expect("example"); 205 | /// assert_eq!(auth, ("Aladdin", "opensesame")); 206 | /// ``` 207 | pub fn raw_auth(&self) -> Option<(&str, &str)> { 208 | if let Auth::Raw(ref u, ref p) = self.auth { 209 | Some((u.as_str(), p.as_str())) 210 | } else { 211 | None 212 | } 213 | } 214 | } 215 | 216 | impl fmt::Debug for Intercept { 217 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 218 | f.debug_struct("Intercept") 219 | .field("uri", &self.uri) 220 | // don't output auth, it's sensitive 221 | .finish() 222 | } 223 | } 224 | 225 | // ===== impl Builder ===== 226 | 227 | impl Builder { 228 | fn from_env() -> Self { 229 | Builder { 230 | is_cgi:
std::env::var_os("REQUEST_METHOD").is_some(), 231 | all: get_first_env(&["ALL_PROXY", "all_proxy"]), 232 | http: get_first_env(&["HTTP_PROXY", "http_proxy"]), 233 | https: get_first_env(&["HTTPS_PROXY", "https_proxy"]), 234 | no: get_first_env(&["NO_PROXY", "no_proxy"]), 235 | } 236 | } 237 | 238 | fn from_system() -> Self { 239 | #[allow(unused_mut)] 240 | let mut builder = Self::from_env(); 241 | 242 | #[cfg(all(feature = "client-proxy-system", target_os = "macos"))] 243 | mac::with_system(&mut builder); 244 | 245 | #[cfg(all(feature = "client-proxy-system", windows))] 246 | win::with_system(&mut builder); 247 | 248 | builder 249 | } 250 | 251 | /// Set the target proxy for all destinations. 252 | pub fn all(mut self, val: S) -> Self 253 | where 254 | S: IntoValue, 255 | { 256 | self.all = val.into_value(); 257 | self 258 | } 259 | 260 | /// Set the target proxy for HTTP destinations. 261 | pub fn http(mut self, val: S) -> Self 262 | where 263 | S: IntoValue, 264 | { 265 | self.http = val.into_value(); 266 | self 267 | } 268 | 269 | /// Set the target proxy for HTTPS destinations. 270 | pub fn https(mut self, val: S) -> Self 271 | where 272 | S: IntoValue, 273 | { 274 | self.https = val.into_value(); 275 | self 276 | } 277 | 278 | /// Set the "no" proxy filter. 279 | /// 280 | /// The rules are as follows: 281 | /// * Entries are expected to be comma-separated (whitespace between entries is ignored) 282 | /// * IP addresses (both IPv4 and IPv6) are allowed, as are optional subnet masks (by adding /size, 283 | /// for example "`192.168.1.0/24`"). 284 | /// * An entry "`*`" matches all hostnames (this is the only wildcard allowed) 285 | /// * Any other entry is considered a domain name (and may contain a leading dot, for example `google.com` 286 | /// and `.google.com` are equivalent) and would match both that domain AND all subdomains. 
287 | /// 288 | /// For example, if `"NO_PROXY=google.com, 192.168.1.0/24"` was set, all of the following would match 289 | /// (and therefore would bypass the proxy): 290 | /// * `http://google.com/` 291 | /// * `http://www.google.com/` 292 | /// * `http://192.168.1.42/` 293 | /// 294 | /// The URL `http://notgoogle.com/` would not match. 295 | pub fn no(mut self, val: S) -> Self 296 | where 297 | S: IntoValue, 298 | { 299 | self.no = val.into_value(); 300 | self 301 | } 302 | 303 | /// Construct a [`Matcher`] using the configured values. 304 | pub fn build(self) -> Matcher { 305 | if self.is_cgi { 306 | return Matcher { 307 | http: None, 308 | https: None, 309 | no: NoProxy::empty(), 310 | }; 311 | } 312 | 313 | let all = parse_env_uri(&self.all); 314 | 315 | Matcher { 316 | http: parse_env_uri(&self.http).or_else(|| all.clone()), 317 | https: parse_env_uri(&self.https).or(all), 318 | no: NoProxy::from_string(&self.no), 319 | } 320 | } 321 | } 322 | 323 | fn get_first_env(names: &[&str]) -> String { 324 | for name in names { 325 | if let Ok(val) = std::env::var(name) { 326 | return val; 327 | } 328 | } 329 | 330 | String::new() 331 | } 332 | 333 | fn parse_env_uri(val: &str) -> Option { 334 | let uri = val.parse::().ok()?; 335 | let mut builder = http::Uri::builder(); 336 | let mut is_httpish = false; 337 | let mut auth = Auth::Empty; 338 | 339 | builder = builder.scheme(match uri.scheme() { 340 | Some(s) => { 341 | if s == &http::uri::Scheme::HTTP || s == &http::uri::Scheme::HTTPS { 342 | is_httpish = true; 343 | s.clone() 344 | } else if s.as_str() == "socks5" || s.as_str() == "socks5h" { 345 | s.clone() 346 | } else { 347 | // can't use this proxy scheme 348 | return None; 349 | } 350 | } 351 | // if no scheme provided, assume they meant 'http' 352 | None => { 353 | is_httpish = true; 354 | http::uri::Scheme::HTTP 355 | } 356 | }); 357 | 358 | let authority = uri.authority()?; 359 | 360 | if let Some((userinfo, host_port)) = authority.as_str().split_once('@') 
{ 361 | let (user, pass) = userinfo.split_once(':')?; 362 | let user = percent_decode_str(user).decode_utf8_lossy(); 363 | let pass = percent_decode_str(pass).decode_utf8_lossy(); 364 | if is_httpish { 365 | auth = Auth::Basic(encode_basic_auth(&user, Some(&pass))); 366 | } else { 367 | auth = Auth::Raw(user.into(), pass.into()); 368 | } 369 | builder = builder.authority(host_port); 370 | } else { 371 | builder = builder.authority(authority.clone()); 372 | } 373 | 374 | // removing any path, but we MUST specify one or the builder errors 375 | builder = builder.path_and_query("/"); 376 | 377 | let dst = builder.build().ok()?; 378 | 379 | Some(Intercept { uri: dst, auth }) 380 | } 381 | 382 | fn encode_basic_auth(user: &str, pass: Option<&str>) -> HeaderValue { 383 | use base64::prelude::BASE64_STANDARD; 384 | use base64::write::EncoderWriter; 385 | use std::io::Write; 386 | 387 | let mut buf = b"Basic ".to_vec(); 388 | { 389 | let mut encoder = EncoderWriter::new(&mut buf, &BASE64_STANDARD); 390 | let _ = write!(encoder, "{user}:"); 391 | if let Some(password) = pass { 392 | let _ = write!(encoder, "{password}"); 393 | } 394 | } 395 | let mut header = HeaderValue::from_bytes(&buf).expect("base64 is always valid HeaderValue"); 396 | header.set_sensitive(true); 397 | header 398 | } 399 | 400 | impl NoProxy { 401 | /* 402 | fn from_env() -> NoProxy { 403 | let raw = std::env::var("NO_PROXY") 404 | .or_else(|_| std::env::var("no_proxy")) 405 | .unwrap_or_default(); 406 | 407 | Self::from_string(&raw) 408 | } 409 | */ 410 | 411 | fn empty() -> NoProxy { 412 | NoProxy { 413 | ips: IpMatcher(Vec::new()), 414 | domains: DomainMatcher(Vec::new()), 415 | } 416 | } 417 | 418 | /// Returns a new no-proxy configuration parsed from a `no_proxy`-style string (an empty 419 | /// string yields a configuration that matches nothing). 420 | /// The rules are as follows: 421 | /// * The string is usually the value of the `NO_PROXY` (or `no_proxy`) environment variable 422 | /// * Blank entries (after trimming whitespace) are ignored
423 | /// * Entries are expected to be comma-separated (whitespace between entries is ignored) 424 | /// * IP addresses (both IPv4 and IPv6) are allowed, as are optional subnet masks (by adding /size, 425 | /// for example "`192.168.1.0/24`"). 426 | /// * An entry "`*`" matches all hostnames (this is the only wildcard allowed) 427 | /// * Any other entry is considered a domain name (and may contain a leading dot, for example `google.com` 428 | /// and `.google.com` are equivalent) and would match both that domain AND all subdomains. 429 | /// 430 | /// For example, if `"NO_PROXY=google.com, 192.168.1.0/24"` was set, all of the following would match 431 | /// (and therefore would bypass the proxy): 432 | /// * `http://google.com/` 433 | /// * `http://www.google.com/` 434 | /// * `http://192.168.1.42/` 435 | /// 436 | /// The URL `http://notgoogle.com/` would not match. 437 | pub fn from_string(no_proxy_list: &str) -> Self { 438 | let mut ips = Vec::new(); 439 | let mut domains = Vec::new(); 440 | let parts = no_proxy_list.split(',').map(str::trim); 441 | for part in parts { 442 | match part.parse::() { 443 | // If we can parse an IP net or address, then use it, otherwise, assume it is a domain 444 | Ok(ip) => ips.push(Ip::Network(ip)), 445 | Err(_) => match part.parse::() { 446 | Ok(addr) => ips.push(Ip::Address(addr)), 447 | Err(_) => { 448 | if !part.trim().is_empty() { 449 | domains.push(part.to_owned()) 450 | } 451 | } 452 | }, 453 | } 454 | } 455 | NoProxy { 456 | ips: IpMatcher(ips), 457 | domains: DomainMatcher(domains), 458 | } 459 | } 460 | 461 | /// Return true if this matches the host (domain or IP). 462 | pub fn contains(&self, host: &str) -> bool { 463 | // According to RFC3986, raw IPv6 hosts will be wrapped in [].
So we need to strip those off 464 | // before parsing the address 465 | let host = if host.starts_with('[') { 466 | let x: &[_] = &['[', ']']; 467 | host.trim_matches(x) 468 | } else { 469 | host 470 | }; 471 | match host.parse::() { 472 | // If we can parse an IP addr, then use it, otherwise, assume it is a domain 473 | Ok(ip) => self.ips.contains(ip), 474 | Err(_) => self.domains.contains(host), 475 | } 476 | } 477 | 478 | fn is_empty(&self) -> bool { 479 | self.ips.0.is_empty() && self.domains.0.is_empty() 480 | } 481 | } 482 | 483 | impl IpMatcher { 484 | fn contains(&self, addr: IpAddr) -> bool { 485 | for ip in &self.0 { 486 | match ip { 487 | Ip::Address(address) => { 488 | if &addr == address { 489 | return true; 490 | } 491 | } 492 | Ip::Network(net) => { 493 | if net.contains(&addr) { 494 | return true; 495 | } 496 | } 497 | } 498 | } 499 | false 500 | } 501 | } 502 | 503 | impl DomainMatcher { 504 | // The following links may be useful to understand the origin of these rules: 505 | // * https://curl.se/libcurl/c/CURLOPT_NOPROXY.html 506 | // * https://github.com/curl/curl/issues/1208 507 | fn contains(&self, domain: &str) -> bool { 508 | let domain_len = domain.len(); 509 | for d in &self.0 { 510 | if d == domain || d.strip_prefix('.') == Some(domain) { 511 | return true; 512 | } else if domain.ends_with(d) { 513 | if d.starts_with('.') { 514 | // If the first character of d is a dot, then the matched suffix of domain also starts 515 | // with a dot, so we are looking at a subdomain of d and that matches 516 | return true; 517 | } else if domain.as_bytes().get(domain_len - d.len() - 1) == Some(&b'.') { 518 | // Given that d is a suffix of domain, if the character just before it in domain is a dot 519 | // then that means we must be matching a subdomain of d, and that matches 520 | return true; 521 | } 522 | } else if d == "*" { 523 | return true; 524 | } 525 | } 526 | false 527 | } 528 | } 529 | 530 | mod builder { 531 | /// A type that can be used as a
`Builder` value. 532 | /// 533 | /// Private and sealed, only visible in docs. 534 | pub trait IntoValue { 535 | #[doc(hidden)] 536 | fn into_value(self) -> String; 537 | } 538 | 539 | impl IntoValue for String { 540 | #[doc(hidden)] 541 | fn into_value(self) -> String { 542 | self 543 | } 544 | } 545 | 546 | impl IntoValue for &String { 547 | #[doc(hidden)] 548 | fn into_value(self) -> String { 549 | self.into() 550 | } 551 | } 552 | 553 | impl IntoValue for &str { 554 | #[doc(hidden)] 555 | fn into_value(self) -> String { 556 | self.into() 557 | } 558 | } 559 | } 560 | 561 | #[cfg(feature = "client-proxy-system")] 562 | #[cfg(target_os = "macos")] 563 | mod mac { 564 | use system_configuration::core_foundation::base::{CFType, TCFType, TCFTypeRef}; 565 | use system_configuration::core_foundation::dictionary::CFDictionary; 566 | use system_configuration::core_foundation::number::CFNumber; 567 | use system_configuration::core_foundation::string::{CFString, CFStringRef}; 568 | use system_configuration::dynamic_store::SCDynamicStoreBuilder; 569 | use system_configuration::sys::schema_definitions::{ 570 | kSCPropNetProxiesHTTPEnable, kSCPropNetProxiesHTTPPort, kSCPropNetProxiesHTTPProxy, 571 | kSCPropNetProxiesHTTPSEnable, kSCPropNetProxiesHTTPSPort, kSCPropNetProxiesHTTPSProxy, 572 | }; 573 | 574 | pub(super) fn with_system(builder: &mut super::Builder) { 575 | let store = SCDynamicStoreBuilder::new("").build(); 576 | 577 | let proxies_map = if let Some(proxies_map) = store.get_proxies() { 578 | proxies_map 579 | } else { 580 | return; 581 | }; 582 | 583 | if builder.http.is_empty() { 584 | let http_proxy_config = parse_setting_from_dynamic_store( 585 | &proxies_map, 586 | unsafe { kSCPropNetProxiesHTTPEnable }, 587 | unsafe { kSCPropNetProxiesHTTPProxy }, 588 | unsafe { kSCPropNetProxiesHTTPPort }, 589 | ); 590 | if let Some(http) = http_proxy_config { 591 | builder.http = http; 592 | } 593 | } 594 | 595 | if builder.https.is_empty() { 596 | let https_proxy_config = 
parse_setting_from_dynamic_store( 597 | &proxies_map, 598 | unsafe { kSCPropNetProxiesHTTPSEnable }, 599 | unsafe { kSCPropNetProxiesHTTPSProxy }, 600 | unsafe { kSCPropNetProxiesHTTPSPort }, 601 | ); 602 | 603 | if let Some(https) = https_proxy_config { 604 | builder.https = https; 605 | } 606 | } 607 | } 608 | 609 | fn parse_setting_from_dynamic_store( 610 | proxies_map: &CFDictionary, 611 | enabled_key: CFStringRef, 612 | host_key: CFStringRef, 613 | port_key: CFStringRef, 614 | ) -> Option { 615 | let proxy_enabled = proxies_map 616 | .find(enabled_key) 617 | .and_then(|flag| flag.downcast::()) 618 | .and_then(|flag| flag.to_i32()) 619 | .unwrap_or(0) 620 | == 1; 621 | 622 | if proxy_enabled { 623 | let proxy_host = proxies_map 624 | .find(host_key) 625 | .and_then(|host| host.downcast::()) 626 | .map(|host| host.to_string()); 627 | let proxy_port = proxies_map 628 | .find(port_key) 629 | .and_then(|port| port.downcast::()) 630 | .and_then(|port| port.to_i32()); 631 | 632 | return match (proxy_host, proxy_port) { 633 | (Some(proxy_host), Some(proxy_port)) => Some(format!("{proxy_host}:{proxy_port}")), 634 | (Some(proxy_host), None) => Some(proxy_host), 635 | (None, Some(_)) => None, 636 | (None, None) => None, 637 | }; 638 | } 639 | 640 | None 641 | } 642 | } 643 | 644 | #[cfg(feature = "client-proxy-system")] 645 | #[cfg(windows)] 646 | mod win { 647 | pub(super) fn with_system(builder: &mut super::Builder) { 648 | let settings = if let Ok(settings) = windows_registry::CURRENT_USER 649 | .open("Software\\Microsoft\\Windows\\CurrentVersion\\Internet Settings") 650 | { 651 | settings 652 | } else { 653 | return; 654 | }; 655 | 656 | if settings.get_u32("ProxyEnable").unwrap_or(0) == 0 { 657 | return; 658 | } 659 | 660 | if builder.http.is_empty() { 661 | if let Ok(val) = settings.get_string("ProxyServer") { 662 | builder.http = val; 663 | } 664 | } 665 | 666 | if builder.no.is_empty() { 667 | if let Ok(val) = settings.get_string("ProxyOverride") { 668 | 
builder.no = val 669 | .split(';') 670 | .map(|s| s.trim()) 671 | .collect::>() 672 | .join(",") 673 | .replace("*.", ""); 674 | } 675 | } 676 | } 677 | } 678 | 679 | #[cfg(test)] 680 | mod tests { 681 | use super::*; 682 | 683 | #[test] 684 | fn test_domain_matcher() { 685 | let domains = vec![".foo.bar".into(), "bar.foo".into()]; 686 | let matcher = DomainMatcher(domains); 687 | 688 | // domains match with leading `.` 689 | assert!(matcher.contains("foo.bar")); 690 | // subdomains match with leading `.` 691 | assert!(matcher.contains("www.foo.bar")); 692 | 693 | // domains match with no leading `.` 694 | assert!(matcher.contains("bar.foo")); 695 | // subdomains match with no leading `.` 696 | assert!(matcher.contains("www.bar.foo")); 697 | 698 | // non-subdomain string prefixes don't match 699 | assert!(!matcher.contains("notfoo.bar")); 700 | assert!(!matcher.contains("notbar.foo")); 701 | } 702 | 703 | #[test] 704 | fn test_no_proxy_wildcard() { 705 | let no_proxy = NoProxy::from_string("*"); 706 | assert!(no_proxy.contains("any.where")); 707 | } 708 | 709 | #[test] 710 | fn test_no_proxy_ip_ranges() { 711 | let no_proxy = 712 | NoProxy::from_string(".foo.bar, bar.baz,10.42.1.1/24,::1,10.124.7.8,2001::/17"); 713 | 714 | let should_not_match = [ 715 | // random url, not in no_proxy 716 | "hyper.rs", 717 | // make sure that random non-subdomain string prefixes don't match 718 | "notfoo.bar", 719 | // make sure that random non-subdomain string prefixes don't match 720 | "notbar.baz", 721 | // ipv4 address out of range 722 | "10.43.1.1", 723 | // ipv4 address out of range 724 | "10.124.7.7", 725 | // ipv6 address out of range 726 | "[ffff:db8:a0b:12f0::1]", 727 | // ipv6 address out of range 728 | "[2005:db8:a0b:12f0::1]", 729 | ]; 730 | 731 | for host in &should_not_match { 732 | assert!(!no_proxy.contains(host), "should not contain {host:?}"); 733 | } 734 | 735 | let should_match = [ 736 | // make sure subdomains (with leading .) 
match 737 | "hello.foo.bar", 738 | // make sure exact matches (without leading .) match (also makes sure spaces between entries work) 739 | "bar.baz", 740 | // make sure subdomains (without leading . in no_proxy) match 741 | "foo.bar.baz", 742 | // make sure subdomains (without leading . in no_proxy) match - this differs from cURL 743 | "foo.bar", 744 | // ipv4 address match within range 745 | "10.42.1.100", 746 | // ipv6 address exact match 747 | "[::1]", 748 | // ipv6 address match within range 749 | "[2001:db8:a0b:12f0::1]", 750 | // ipv4 address exact match 751 | "10.124.7.8", 752 | ]; 753 | 754 | for host in &should_match { 755 | assert!(no_proxy.contains(host), "should contain {host:?}"); 756 | } 757 | } 758 | 759 | macro_rules! p { 760 | ($($n:ident = $v:expr,)*) => ({Builder { 761 | $($n: $v.into(),)* 762 | ..Builder::default() 763 | }.build()}); 764 | } 765 | 766 | fn intercept(p: &Matcher, u: &str) -> Intercept { 767 | p.intercept(&u.parse().unwrap()).unwrap() 768 | } 769 | 770 | #[test] 771 | fn test_all_proxy() { 772 | let p = p! { 773 | all = "http://om.nom", 774 | }; 775 | 776 | assert_eq!("http://om.nom", intercept(&p, "http://example.com").uri()); 777 | 778 | assert_eq!("http://om.nom", intercept(&p, "https://example.com").uri()); 779 | } 780 | 781 | #[test] 782 | fn test_specific_overrides_all() { 783 | let p = p! { 784 | all = "http://no.pe", 785 | http = "http://y.ep", 786 | }; 787 | 788 | assert_eq!("http://no.pe", intercept(&p, "https://example.com").uri()); 789 | 790 | // the http rule is "more specific" than the all rule 791 | assert_eq!("http://y.ep", intercept(&p, "http://example.com").uri()); 792 | } 793 | 794 | #[test] 795 | fn test_parse_no_scheme_defaults_to_http() { 796 | let p = p! 
{ 797 | https = "y.ep", 798 | http = "127.0.0.1:8887", 799 | }; 800 | 801 | assert_eq!(intercept(&p, "https://example.local").uri(), "http://y.ep"); 802 | assert_eq!( 803 | intercept(&p, "http://example.local").uri(), 804 | "http://127.0.0.1:8887" 805 | ); 806 | } 807 | 808 | #[test] 809 | fn test_parse_http_auth() { 810 | let p = p! { 811 | all = "http://Aladdin:opensesame@y.ep", 812 | }; 813 | 814 | let proxy = intercept(&p, "https://example.local"); 815 | assert_eq!(proxy.uri(), "http://y.ep"); 816 | assert_eq!( 817 | proxy.basic_auth().expect("basic_auth"), 818 | "Basic QWxhZGRpbjpvcGVuc2VzYW1l" 819 | ); 820 | } 821 | 822 | #[test] 823 | fn test_parse_http_auth_without_scheme() { 824 | let p = p! { 825 | all = "Aladdin:opensesame@y.ep", 826 | }; 827 | 828 | let proxy = intercept(&p, "https://example.local"); 829 | assert_eq!(proxy.uri(), "http://y.ep"); 830 | assert_eq!( 831 | proxy.basic_auth().expect("basic_auth"), 832 | "Basic QWxhZGRpbjpvcGVuc2VzYW1l" 833 | ); 834 | } 835 | 836 | #[test] 837 | fn test_dont_parse_http_when_is_cgi() { 838 | let mut builder = Matcher::builder(); 839 | builder.is_cgi = true; 840 | builder.http = "http://never.gonna.let.you.go".into(); 841 | let m = builder.build(); 842 | 843 | assert!(m.intercept(&"http://rick.roll".parse().unwrap()).is_none()); 844 | } 845 | } 846 | -------------------------------------------------------------------------------- /src/client/proxy/mod.rs: -------------------------------------------------------------------------------- 1 | //! 
Proxy utilities 2 | 3 | pub mod matcher; 4 | -------------------------------------------------------------------------------- /src/client/service.rs: -------------------------------------------------------------------------------- 1 | struct ConnectingPool { 2 | connector: C, 3 | pool: P, 4 | } 5 | 6 | struct PoolableSvc(S); 7 | 8 | 9 | -------------------------------------------------------------------------------- /src/common/exec.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | 3 | use hyper::rt::Executor; 4 | use std::fmt; 5 | use std::future::Future; 6 | use std::pin::Pin; 7 | use std::sync::Arc; 8 | 9 | pub(crate) type BoxSendFuture = Pin + Send>>; 10 | 11 | // Either the user provides an executor for background tasks, or we use 12 | // `tokio::spawn`. 13 | #[derive(Clone)] 14 | pub(crate) enum Exec { 15 | Executor(Arc + Send + Sync>), 16 | } 17 | 18 | // ===== impl Exec ===== 19 | 20 | impl Exec { 21 | pub(crate) fn new(inner: E) -> Self 22 | where 23 | E: Executor + Send + Sync + 'static, 24 | { 25 | Exec::Executor(Arc::new(inner)) 26 | } 27 | 28 | pub(crate) fn execute(&self, fut: F) 29 | where 30 | F: Future + Send + 'static, 31 | { 32 | match *self { 33 | Exec::Executor(ref e) => { 34 | e.execute(Box::pin(fut)); 35 | } 36 | } 37 | } 38 | } 39 | 40 | impl fmt::Debug for Exec { 41 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 42 | f.debug_struct("Exec").finish() 43 | } 44 | } 45 | 46 | impl hyper::rt::Executor for Exec 47 | where 48 | F: Future + Send + 'static, 49 | { 50 | fn execute(&self, fut: F) { 51 | Exec::execute(self, fut); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/common/future.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | future::Future, 3 | pin::Pin, 4 | task::{Context, Poll}, 5 | }; 6 | 7 | // TODO: replace with `std::future::poll_fn` once MSRV >= 
1.64 8 | pub(crate) fn poll_fn(f: F) -> PollFn 9 | where 10 | F: FnMut(&mut Context<'_>) -> Poll, 11 | { 12 | PollFn { f } 13 | } 14 | 15 | pub(crate) struct PollFn { 16 | f: F, 17 | } 18 | 19 | impl Unpin for PollFn {} 20 | 21 | impl Future for PollFn 22 | where 23 | F: FnMut(&mut Context<'_>) -> Poll, 24 | { 25 | type Output = T; 26 | 27 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 28 | (self.f)(cx) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/common/lazy.rs: -------------------------------------------------------------------------------- 1 | use pin_project_lite::pin_project; 2 | 3 | use std::future::Future; 4 | use std::pin::Pin; 5 | use std::task::{self, Poll}; 6 | 7 | pub(crate) trait Started: Future { 8 | fn started(&self) -> bool; 9 | } 10 | 11 | pub(crate) fn lazy(func: F) -> Lazy 12 | where 13 | F: FnOnce() -> R, 14 | R: Future + Unpin, 15 | { 16 | Lazy { 17 | inner: Inner::Init { func }, 18 | } 19 | } 20 | 21 | // FIXME: allow() required due to `impl Trait` leaking types to this lint 22 | pin_project! { 23 | #[allow(missing_debug_implementations)] 24 | pub(crate) struct Lazy { 25 | #[pin] 26 | inner: Inner, 27 | } 28 | } 29 | 30 | pin_project! { 31 | #[project = InnerProj] 32 | #[project_replace = InnerProjReplace] 33 | enum Inner { 34 | Init { func: F }, 35 | Fut { #[pin] fut: R }, 36 | Empty, 37 | } 38 | } 39 | 40 | impl Started for Lazy 41 | where 42 | F: FnOnce() -> R, 43 | R: Future, 44 | { 45 | fn started(&self) -> bool { 46 | match self.inner { 47 | Inner::Init { .. } => false, 48 | Inner::Fut { .. 
} | Inner::Empty => true, 49 | } 50 | } 51 | } 52 | 53 | impl Future for Lazy 54 | where 55 | F: FnOnce() -> R, 56 | R: Future, 57 | { 58 | type Output = R::Output; 59 | 60 | fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { 61 | let mut this = self.project(); 62 | 63 | if let InnerProj::Fut { fut } = this.inner.as_mut().project() { 64 | return fut.poll(cx); 65 | } 66 | 67 | match this.inner.as_mut().project_replace(Inner::Empty) { 68 | InnerProjReplace::Init { func } => { 69 | this.inner.set(Inner::Fut { fut: func() }); 70 | if let InnerProj::Fut { fut } = this.inner.project() { 71 | return fut.poll(cx); 72 | } 73 | unreachable!() 74 | } 75 | _ => unreachable!("lazy state wrong"), 76 | } 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/common/mod.rs: -------------------------------------------------------------------------------- 1 | #![allow(missing_docs)] 2 | 3 | pub(crate) mod exec; 4 | #[cfg(feature = "client")] 5 | mod lazy; 6 | pub(crate) mod rewind; 7 | #[cfg(feature = "client")] 8 | mod sync; 9 | pub(crate) mod timer; 10 | 11 | #[cfg(feature = "client")] 12 | pub(crate) use exec::Exec; 13 | 14 | #[cfg(feature = "client")] 15 | pub(crate) use lazy::{lazy, Started as Lazy}; 16 | #[cfg(feature = "client")] 17 | pub(crate) use sync::SyncWrapper; 18 | 19 | pub(crate) mod future; 20 | -------------------------------------------------------------------------------- /src/common/rewind.rs: -------------------------------------------------------------------------------- 1 | use std::{cmp, io}; 2 | 3 | use bytes::{Buf, Bytes}; 4 | use hyper::rt::{Read, ReadBufCursor, Write}; 5 | 6 | use std::{ 7 | pin::Pin, 8 | task::{self, Poll}, 9 | }; 10 | 11 | /// Combine a buffer with an IO, rewinding reads to use the buffer. 
12 | #[derive(Debug)] 13 | pub(crate) struct Rewind { 14 | pub(crate) pre: Option, 15 | pub(crate) inner: T, 16 | } 17 | 18 | impl Rewind { 19 | #[cfg(all(feature = "server", any(feature = "http1", feature = "http2")))] 20 | pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self { 21 | Rewind { 22 | pre: Some(buf), 23 | inner: io, 24 | } 25 | } 26 | } 27 | 28 | impl Read for Rewind 29 | where 30 | T: Read + Unpin, 31 | { 32 | fn poll_read( 33 | mut self: Pin<&mut Self>, 34 | cx: &mut task::Context<'_>, 35 | mut buf: ReadBufCursor<'_>, 36 | ) -> Poll> { 37 | if let Some(mut prefix) = self.pre.take() { 38 | // If there are no remaining bytes, let the bytes get dropped. 39 | if !prefix.is_empty() { 40 | let copy_len = cmp::min(prefix.len(), buf.remaining()); 41 | buf.put_slice(&prefix[..copy_len]); 42 | prefix.advance(copy_len); 43 | // Put back what's left 44 | if !prefix.is_empty() { 45 | self.pre = Some(prefix); 46 | } 47 | 48 | return Poll::Ready(Ok(())); 49 | } 50 | } 51 | Pin::new(&mut self.inner).poll_read(cx, buf) 52 | } 53 | } 54 | 55 | impl Write for Rewind 56 | where 57 | T: Write + Unpin, 58 | { 59 | fn poll_write( 60 | mut self: Pin<&mut Self>, 61 | cx: &mut task::Context<'_>, 62 | buf: &[u8], 63 | ) -> Poll> { 64 | Pin::new(&mut self.inner).poll_write(cx, buf) 65 | } 66 | 67 | fn poll_write_vectored( 68 | mut self: Pin<&mut Self>, 69 | cx: &mut task::Context<'_>, 70 | bufs: &[io::IoSlice<'_>], 71 | ) -> Poll> { 72 | Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) 73 | } 74 | 75 | fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { 76 | Pin::new(&mut self.inner).poll_flush(cx) 77 | } 78 | 79 | fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { 80 | Pin::new(&mut self.inner).poll_shutdown(cx) 81 | } 82 | 83 | fn is_write_vectored(&self) -> bool { 84 | self.inner.is_write_vectored() 85 | } 86 | } 87 | 88 | /* 89 | #[cfg(test)] 90 | mod tests { 91 | use super::Rewind; 92 | use bytes::Bytes; 93 
| use tokio::io::AsyncReadExt; 94 | 95 | #[cfg(not(miri))] 96 | #[tokio::test] 97 | async fn partial_rewind() { 98 | let underlying = [104, 101, 108, 108, 111]; 99 | 100 | let mock = tokio_test::io::Builder::new().read(&underlying).build(); 101 | 102 | let mut stream = Rewind::new(mock); 103 | 104 | // Read off some bytes, ensure we filled o1 105 | let mut buf = [0; 2]; 106 | stream.read_exact(&mut buf).await.expect("read1"); 107 | 108 | // Rewind the stream so that it is as if we never read in the first place. 109 | stream.rewind(Bytes::copy_from_slice(&buf[..])); 110 | 111 | let mut buf = [0; 5]; 112 | stream.read_exact(&mut buf).await.expect("read1"); 113 | 114 | // At this point we should have read everything that was in the MockStream 115 | assert_eq!(&buf, &underlying); 116 | } 117 | 118 | #[cfg(not(miri))] 119 | #[tokio::test] 120 | async fn full_rewind() { 121 | let underlying = [104, 101, 108, 108, 111]; 122 | 123 | let mock = tokio_test::io::Builder::new().read(&underlying).build(); 124 | 125 | let mut stream = Rewind::new(mock); 126 | 127 | let mut buf = [0; 5]; 128 | stream.read_exact(&mut buf).await.expect("read1"); 129 | 130 | // Rewind the stream so that it is as if we never read in the first place. 131 | stream.rewind(Bytes::copy_from_slice(&buf[..])); 132 | 133 | let mut buf = [0; 5]; 134 | stream.read_exact(&mut buf).await.expect("read1"); 135 | } 136 | } 137 | */ 138 | -------------------------------------------------------------------------------- /src/common/sync.rs: -------------------------------------------------------------------------------- 1 | pub(crate) struct SyncWrapper(T); 2 | 3 | impl SyncWrapper { 4 | /// Creates a new SyncWrapper containing the given value. 
5 | /// 6 | /// # Examples 7 | /// 8 | /// ```ignore 9 | /// use hyper::common::sync_wrapper::SyncWrapper; 10 | /// 11 | /// let wrapped = SyncWrapper::new(42); 12 | /// ``` 13 | pub(crate) fn new(value: T) -> Self { 14 | Self(value) 15 | } 16 | 17 | /// Acquires a reference to the protected value. 18 | /// 19 | /// This is safe because it requires an exclusive reference to the wrapper. Therefore this method 20 | /// neither panics nor does it return an error. This is in contrast to [`Mutex::get_mut`] which 21 | /// returns an error if another thread panicked while holding the lock. It is not recommended 22 | /// to send an exclusive reference to a potentially damaged value to another thread for further 23 | /// processing. 24 | /// 25 | /// [`Mutex::get_mut`]: https://doc.rust-lang.org/std/sync/struct.Mutex.html#method.get_mut 26 | /// 27 | /// # Examples 28 | /// 29 | /// ```ignore 30 | /// use hyper::common::sync_wrapper::SyncWrapper; 31 | /// 32 | /// let mut wrapped = SyncWrapper::new(42); 33 | /// let value = wrapped.get_mut(); 34 | /// *value = 0; 35 | /// assert_eq!(*wrapped.get_mut(), 0); 36 | /// ``` 37 | pub(crate) fn get_mut(&mut self) -> &mut T { 38 | &mut self.0 39 | } 40 | 41 | /// Consumes this wrapper, returning the underlying data. 42 | /// 43 | /// This is safe because it requires ownership of the wrapper, aherefore this method will neither 44 | /// panic nor does it return an error. This is in contrast to [`Mutex::into_inner`] which 45 | /// returns an error if another thread panicked while holding the lock. It is not recommended 46 | /// to send an exclusive reference to a potentially damaged value to another thread for further 47 | /// processing. 
48 | /// 49 | /// [`Mutex::into_inner`]: https://doc.rust-lang.org/std/sync/struct.Mutex.html#method.into_inner 50 | /// 51 | /// # Examples 52 | /// 53 | /// ```ignore 54 | /// use hyper::common::sync_wrapper::SyncWrapper; 55 | /// 56 | /// let mut wrapped = SyncWrapper::new(42); 57 | /// assert_eq!(wrapped.into_inner(), 42); 58 | /// ``` 59 | #[allow(dead_code)] 60 | pub(crate) fn into_inner(self) -> T { 61 | self.0 62 | } 63 | } 64 | 65 | // this is safe because the only operations permitted on this data structure require exclusive 66 | // access or ownership 67 | unsafe impl Sync for SyncWrapper {} 68 | -------------------------------------------------------------------------------- /src/common/timer.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | 3 | use std::fmt; 4 | use std::pin::Pin; 5 | use std::sync::Arc; 6 | use std::time::Duration; 7 | use std::time::Instant; 8 | 9 | use hyper::rt::Sleep; 10 | 11 | #[derive(Clone)] 12 | pub(crate) struct Timer(Arc); 13 | 14 | // =====impl Timer===== 15 | impl Timer { 16 | pub(crate) fn new(inner: T) -> Self 17 | where 18 | T: hyper::rt::Timer + Send + Sync + 'static, 19 | { 20 | Self(Arc::new(inner)) 21 | } 22 | } 23 | 24 | impl fmt::Debug for Timer { 25 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 26 | f.debug_struct("Timer").finish() 27 | } 28 | } 29 | 30 | impl hyper::rt::Timer for Timer { 31 | fn sleep(&self, duration: Duration) -> Pin> { 32 | self.0.sleep(duration) 33 | } 34 | 35 | fn sleep_until(&self, deadline: Instant) -> Pin> { 36 | self.0.sleep_until(deadline) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | /* 2 | use std::error::Error; 3 | 4 | pub(crate) fn find<'a, E: Error + 'static>(top: &'a (dyn Error + 'static)) -> Option<&'a E> { 5 | let mut err = Some(top); 6 | while let 
Some(src) = err { 7 | if src.is::() { 8 | return src.downcast_ref(); 9 | } 10 | err = src.source(); 11 | } 12 | None 13 | } 14 | */ 15 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![deny(missing_docs)] 2 | #![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))] 3 | 4 | //! Utilities for working with hyper. 5 | //! 6 | //! This crate is less-stable than [`hyper`](https://docs.rs/hyper). However, 7 | //! does respect Rust's semantic version regarding breaking changes. 8 | 9 | #[cfg(feature = "client")] 10 | pub mod client; 11 | mod common; 12 | pub mod rt; 13 | #[cfg(feature = "server")] 14 | pub mod server; 15 | #[cfg(any(feature = "service", feature = "client-legacy"))] 16 | pub mod service; 17 | 18 | mod error; 19 | -------------------------------------------------------------------------------- /src/rt/io.rs: -------------------------------------------------------------------------------- 1 | use std::marker::Unpin; 2 | use std::pin::Pin; 3 | use std::task::Poll; 4 | 5 | use futures_core::ready; 6 | use hyper::rt::{Read, ReadBuf, Write}; 7 | 8 | use crate::common::future::poll_fn; 9 | 10 | pub(crate) async fn read(io: &mut T, buf: &mut [u8]) -> Result 11 | where 12 | T: Read + Unpin, 13 | { 14 | poll_fn(move |cx| { 15 | let mut buf = ReadBuf::new(buf); 16 | ready!(Pin::new(&mut *io).poll_read(cx, buf.unfilled()))?; 17 | Poll::Ready(Ok(buf.filled().len())) 18 | }) 19 | .await 20 | } 21 | 22 | pub(crate) async fn write_all(io: &mut T, buf: &[u8]) -> Result<(), std::io::Error> 23 | where 24 | T: Write + Unpin, 25 | { 26 | let mut n = 0; 27 | poll_fn(move |cx| { 28 | while n < buf.len() { 29 | n += ready!(Pin::new(&mut *io).poll_write(cx, &buf[n..])?); 30 | } 31 | Poll::Ready(Ok(())) 32 | }) 33 | .await 34 | } 35 | -------------------------------------------------------------------------------- /src/rt/mod.rs: 
-------------------------------------------------------------------------------- 1 | //! Runtime utilities 2 | 3 | #[cfg(feature = "client-legacy")] 4 | mod io; 5 | #[cfg(feature = "client-legacy")] 6 | pub(crate) use self::io::{read, write_all}; 7 | 8 | #[cfg(feature = "tokio")] 9 | pub mod tokio; 10 | 11 | #[cfg(feature = "tokio")] 12 | pub use self::tokio::{TokioExecutor, TokioIo, TokioTimer}; 13 | -------------------------------------------------------------------------------- /src/rt/tokio.rs: -------------------------------------------------------------------------------- 1 | //! [`tokio`] runtime components integration for [`hyper`]. 2 | //! 3 | //! [`hyper::rt`] exposes a set of traits to allow hyper to be agnostic to 4 | //! its underlying asynchronous runtime. This submodule provides glue for 5 | //! [`tokio`] users to bridge those types to [`hyper`]'s interfaces. 6 | //! 7 | //! # IO 8 | //! 9 | //! [`hyper`] abstracts over asynchronous readers and writers using [`Read`] 10 | //! and [`Write`], while [`tokio`] abstracts over this using [`AsyncRead`] 11 | //! and [`AsyncWrite`]. This submodule provides a collection of IO adaptors 12 | //! to bridge these two IO ecosystems together: [`TokioIo`], 13 | //! [`WithHyperIo`], and [`WithTokioIo`]. 14 | //! 15 | //! To compare and constrast these IO adaptors and to help explain which 16 | //! is the proper choice for your needs, here is a table showing which IO 17 | //! traits these implement, given two types `T` and `H` which implement 18 | //! Tokio's and Hyper's corresponding IO traits: 19 | //! 20 | //! | | [`AsyncRead`] | [`AsyncWrite`] | [`Read`] | [`Write`] | 21 | //! |--------------------|------------------|-------------------|--------------|--------------| 22 | //! | `T` | ✅ **true** | ✅ **true** | ❌ **false** | ❌ **false** | 23 | //! | `H` | ❌ **false** | ❌ **false** | ✅ **true** | ✅ **true** | 24 | //! | [`TokioIo`] | ❌ **false** | ❌ **false** | ✅ **true** | ✅ **true** | 25 | //! 
| [`TokioIo`] | ✅ **true** | ✅ **true** | ❌ **false** | ❌ **false** | 26 | //! | [`WithHyperIo`] | ✅ **true** | ✅ **true** | ✅ **true** | ✅ **true** | 27 | //! | [`WithHyperIo`] | ❌ **false** | ❌ **false** | ❌ **false** | ❌ **false** | 28 | //! | [`WithTokioIo`] | ❌ **false** | ❌ **false** | ❌ **false** | ❌ **false** | 29 | //! | [`WithTokioIo`] | ✅ **true** | ✅ **true** | ✅ **true** | ✅ **true** | 30 | //! 31 | //! For most situations, [`TokioIo`] is the proper choice. This should be 32 | //! constructed, wrapping some underlying [`hyper`] or [`tokio`] IO, at the 33 | //! call-site of a function like [`hyper::client::conn::http1::handshake`]. 34 | //! 35 | //! [`TokioIo`] switches across these ecosystems, but notably does not 36 | //! preserve the existing IO trait implementations of its underlying IO. If 37 | //! one wishes to _extend_ IO with additional implementations, 38 | //! [`WithHyperIo`] and [`WithTokioIo`] are the correct choice. 39 | //! 40 | //! For example, a Tokio reader/writer can be wrapped in [`WithHyperIo`]. 41 | //! That will implement _both_ sets of IO traits. Conversely, 42 | //! [`WithTokioIo`] will implement both sets of IO traits given a 43 | //! reader/writer that implements Hyper's [`Read`] and [`Write`]. 44 | //! 45 | //! See [`tokio::io`] and ["_Asynchronous IO_"][tokio-async-docs] for more 46 | //! information. 47 | //! 48 | //! [`AsyncRead`]: tokio::io::AsyncRead 49 | //! [`AsyncWrite`]: tokio::io::AsyncWrite 50 | //! [`Read`]: hyper::rt::Read 51 | //! [`Write`]: hyper::rt::Write 52 | //! 
[tokio-async-docs]: https://docs.rs/tokio/latest/tokio/#asynchronous-io 53 | 54 | use std::{ 55 | future::Future, 56 | pin::Pin, 57 | task::{Context, Poll}, 58 | time::{Duration, Instant}, 59 | }; 60 | 61 | use hyper::rt::{Executor, Sleep, Timer}; 62 | use pin_project_lite::pin_project; 63 | 64 | #[cfg(feature = "tracing")] 65 | use tracing::instrument::Instrument; 66 | 67 | pub use self::{with_hyper_io::WithHyperIo, with_tokio_io::WithTokioIo}; 68 | 69 | mod with_hyper_io; 70 | mod with_tokio_io; 71 | 72 | /// Future executor that utilises `tokio` threads. 73 | #[non_exhaustive] 74 | #[derive(Default, Debug, Clone)] 75 | pub struct TokioExecutor {} 76 | 77 | pin_project! { 78 | /// A wrapper that implements Tokio's IO traits for an inner type that 79 | /// implements hyper's IO traits, or vice versa (implements hyper's IO 80 | /// traits for a type that implements Tokio's IO traits). 81 | #[derive(Debug)] 82 | pub struct TokioIo { 83 | #[pin] 84 | inner: T, 85 | } 86 | } 87 | 88 | /// A Timer that uses the tokio runtime. 89 | #[non_exhaustive] 90 | #[derive(Default, Clone, Debug)] 91 | pub struct TokioTimer; 92 | 93 | // Use TokioSleep to get tokio::time::Sleep to implement Unpin. 94 | // see https://docs.rs/tokio/latest/tokio/time/struct.Sleep.html 95 | pin_project! { 96 | #[derive(Debug)] 97 | struct TokioSleep { 98 | #[pin] 99 | inner: tokio::time::Sleep, 100 | } 101 | } 102 | 103 | // ===== impl TokioExecutor ===== 104 | 105 | impl Executor for TokioExecutor 106 | where 107 | Fut: Future + Send + 'static, 108 | Fut::Output: Send + 'static, 109 | { 110 | fn execute(&self, fut: Fut) { 111 | #[cfg(feature = "tracing")] 112 | tokio::spawn(fut.in_current_span()); 113 | 114 | #[cfg(not(feature = "tracing"))] 115 | tokio::spawn(fut); 116 | } 117 | } 118 | 119 | impl TokioExecutor { 120 | /// Create new executor that relies on [`tokio::spawn`] to execute futures. 
121 | pub fn new() -> Self { 122 | Self {} 123 | } 124 | } 125 | 126 | // ==== impl TokioIo ===== 127 | 128 | impl TokioIo { 129 | /// Wrap a type implementing Tokio's or hyper's IO traits. 130 | pub fn new(inner: T) -> Self { 131 | Self { inner } 132 | } 133 | 134 | /// Borrow the inner type. 135 | pub fn inner(&self) -> &T { 136 | &self.inner 137 | } 138 | 139 | /// Mut borrow the inner type. 140 | pub fn inner_mut(&mut self) -> &mut T { 141 | &mut self.inner 142 | } 143 | 144 | /// Consume this wrapper and get the inner type. 145 | pub fn into_inner(self) -> T { 146 | self.inner 147 | } 148 | } 149 | 150 | impl hyper::rt::Read for TokioIo 151 | where 152 | T: tokio::io::AsyncRead, 153 | { 154 | fn poll_read( 155 | self: Pin<&mut Self>, 156 | cx: &mut Context<'_>, 157 | mut buf: hyper::rt::ReadBufCursor<'_>, 158 | ) -> Poll> { 159 | let n = unsafe { 160 | let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); 161 | match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { 162 | Poll::Ready(Ok(())) => tbuf.filled().len(), 163 | other => return other, 164 | } 165 | }; 166 | 167 | unsafe { 168 | buf.advance(n); 169 | } 170 | Poll::Ready(Ok(())) 171 | } 172 | } 173 | 174 | impl hyper::rt::Write for TokioIo 175 | where 176 | T: tokio::io::AsyncWrite, 177 | { 178 | fn poll_write( 179 | self: Pin<&mut Self>, 180 | cx: &mut Context<'_>, 181 | buf: &[u8], 182 | ) -> Poll> { 183 | tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) 184 | } 185 | 186 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 187 | tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) 188 | } 189 | 190 | fn poll_shutdown( 191 | self: Pin<&mut Self>, 192 | cx: &mut Context<'_>, 193 | ) -> Poll> { 194 | tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) 195 | } 196 | 197 | fn is_write_vectored(&self) -> bool { 198 | tokio::io::AsyncWrite::is_write_vectored(&self.inner) 199 | } 200 | 201 | fn poll_write_vectored( 202 | self: 
Pin<&mut Self>, 203 | cx: &mut Context<'_>, 204 | bufs: &[std::io::IoSlice<'_>], 205 | ) -> Poll> { 206 | tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) 207 | } 208 | } 209 | 210 | impl tokio::io::AsyncRead for TokioIo 211 | where 212 | T: hyper::rt::Read, 213 | { 214 | fn poll_read( 215 | self: Pin<&mut Self>, 216 | cx: &mut Context<'_>, 217 | tbuf: &mut tokio::io::ReadBuf<'_>, 218 | ) -> Poll> { 219 | //let init = tbuf.initialized().len(); 220 | let filled = tbuf.filled().len(); 221 | let sub_filled = unsafe { 222 | let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); 223 | 224 | match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { 225 | Poll::Ready(Ok(())) => buf.filled().len(), 226 | other => return other, 227 | } 228 | }; 229 | 230 | let n_filled = filled + sub_filled; 231 | // At least sub_filled bytes had to have been initialized. 232 | let n_init = sub_filled; 233 | unsafe { 234 | tbuf.assume_init(n_init); 235 | tbuf.set_filled(n_filled); 236 | } 237 | 238 | Poll::Ready(Ok(())) 239 | } 240 | } 241 | 242 | impl tokio::io::AsyncWrite for TokioIo 243 | where 244 | T: hyper::rt::Write, 245 | { 246 | fn poll_write( 247 | self: Pin<&mut Self>, 248 | cx: &mut Context<'_>, 249 | buf: &[u8], 250 | ) -> Poll> { 251 | hyper::rt::Write::poll_write(self.project().inner, cx, buf) 252 | } 253 | 254 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 255 | hyper::rt::Write::poll_flush(self.project().inner, cx) 256 | } 257 | 258 | fn poll_shutdown( 259 | self: Pin<&mut Self>, 260 | cx: &mut Context<'_>, 261 | ) -> Poll> { 262 | hyper::rt::Write::poll_shutdown(self.project().inner, cx) 263 | } 264 | 265 | fn is_write_vectored(&self) -> bool { 266 | hyper::rt::Write::is_write_vectored(&self.inner) 267 | } 268 | 269 | fn poll_write_vectored( 270 | self: Pin<&mut Self>, 271 | cx: &mut Context<'_>, 272 | bufs: &[std::io::IoSlice<'_>], 273 | ) -> Poll> { 274 | 
hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) 275 | } 276 | } 277 | 278 | // ==== impl TokioTimer ===== 279 | 280 | impl Timer for TokioTimer { 281 | fn sleep(&self, duration: Duration) -> Pin> { 282 | Box::pin(TokioSleep { 283 | inner: tokio::time::sleep(duration), 284 | }) 285 | } 286 | 287 | fn sleep_until(&self, deadline: Instant) -> Pin> { 288 | Box::pin(TokioSleep { 289 | inner: tokio::time::sleep_until(deadline.into()), 290 | }) 291 | } 292 | 293 | fn reset(&self, sleep: &mut Pin>, new_deadline: Instant) { 294 | if let Some(sleep) = sleep.as_mut().downcast_mut_pin::() { 295 | sleep.reset(new_deadline) 296 | } 297 | } 298 | } 299 | 300 | impl TokioTimer { 301 | /// Create a new TokioTimer 302 | pub fn new() -> Self { 303 | Self {} 304 | } 305 | } 306 | 307 | impl Future for TokioSleep { 308 | type Output = (); 309 | 310 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 311 | self.project().inner.poll(cx) 312 | } 313 | } 314 | 315 | impl Sleep for TokioSleep {} 316 | 317 | impl TokioSleep { 318 | fn reset(self: Pin<&mut Self>, deadline: Instant) { 319 | self.project().inner.as_mut().reset(deadline.into()); 320 | } 321 | } 322 | 323 | #[cfg(test)] 324 | mod tests { 325 | use crate::rt::TokioExecutor; 326 | use hyper::rt::Executor; 327 | use tokio::sync::oneshot; 328 | 329 | #[cfg(not(miri))] 330 | #[tokio::test] 331 | async fn simple_execute() -> Result<(), Box> { 332 | let (tx, rx) = oneshot::channel(); 333 | let executor = TokioExecutor::new(); 334 | executor.execute(async move { 335 | tx.send(()).unwrap(); 336 | }); 337 | rx.await.map_err(Into::into) 338 | } 339 | } 340 | -------------------------------------------------------------------------------- /src/rt/tokio/with_hyper_io.rs: -------------------------------------------------------------------------------- 1 | use pin_project_lite::pin_project; 2 | use std::{ 3 | pin::Pin, 4 | task::{Context, Poll}, 5 | }; 6 | 7 | pin_project! 
{ 8 | /// Extends an underlying [`tokio`] I/O with [`hyper`] I/O implementations. 9 | /// 10 | /// This implements [`Read`] and [`Write`] given an inner type that implements [`AsyncRead`] 11 | /// and [`AsyncWrite`], respectively. 12 | #[derive(Debug)] 13 | pub struct WithHyperIo { 14 | #[pin] 15 | inner: I, 16 | } 17 | } 18 | 19 | // ==== impl WithHyperIo ===== 20 | 21 | impl WithHyperIo { 22 | /// Wraps the inner I/O in an [`WithHyperIo`] 23 | pub fn new(inner: I) -> Self { 24 | Self { inner } 25 | } 26 | 27 | /// Returns a reference to the inner type. 28 | pub fn inner(&self) -> &I { 29 | &self.inner 30 | } 31 | 32 | /// Returns a mutable reference to the inner type. 33 | pub fn inner_mut(&mut self) -> &mut I { 34 | &mut self.inner 35 | } 36 | 37 | /// Consumes this wrapper and returns the inner type. 38 | pub fn into_inner(self) -> I { 39 | self.inner 40 | } 41 | } 42 | 43 | /// [`WithHyperIo`] is [`Read`] if `I` is [`AsyncRead`]. 44 | /// 45 | /// [`AsyncRead`]: tokio::io::AsyncRead 46 | /// [`Read`]: hyper::rt::Read 47 | impl hyper::rt::Read for WithHyperIo 48 | where 49 | I: tokio::io::AsyncRead, 50 | { 51 | fn poll_read( 52 | self: Pin<&mut Self>, 53 | cx: &mut Context<'_>, 54 | mut buf: hyper::rt::ReadBufCursor<'_>, 55 | ) -> Poll> { 56 | let n = unsafe { 57 | let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut()); 58 | match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) { 59 | Poll::Ready(Ok(())) => tbuf.filled().len(), 60 | other => return other, 61 | } 62 | }; 63 | 64 | unsafe { 65 | buf.advance(n); 66 | } 67 | Poll::Ready(Ok(())) 68 | } 69 | } 70 | 71 | /// [`WithHyperIo`] is [`Write`] if `I` is [`AsyncWrite`]. 
72 | /// 73 | /// [`AsyncWrite`]: tokio::io::AsyncWrite 74 | /// [`Write`]: hyper::rt::Write 75 | impl hyper::rt::Write for WithHyperIo 76 | where 77 | I: tokio::io::AsyncWrite, 78 | { 79 | fn poll_write( 80 | self: Pin<&mut Self>, 81 | cx: &mut Context<'_>, 82 | buf: &[u8], 83 | ) -> Poll> { 84 | tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf) 85 | } 86 | 87 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 88 | tokio::io::AsyncWrite::poll_flush(self.project().inner, cx) 89 | } 90 | 91 | fn poll_shutdown( 92 | self: Pin<&mut Self>, 93 | cx: &mut Context<'_>, 94 | ) -> Poll> { 95 | tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx) 96 | } 97 | 98 | fn is_write_vectored(&self) -> bool { 99 | tokio::io::AsyncWrite::is_write_vectored(&self.inner) 100 | } 101 | 102 | fn poll_write_vectored( 103 | self: Pin<&mut Self>, 104 | cx: &mut Context<'_>, 105 | bufs: &[std::io::IoSlice<'_>], 106 | ) -> Poll> { 107 | tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) 108 | } 109 | } 110 | 111 | /// [`WithHyperIo`] exposes its inner `I`'s [`AsyncRead`] implementation. 112 | /// 113 | /// [`AsyncRead`]: tokio::io::AsyncRead 114 | impl tokio::io::AsyncRead for WithHyperIo 115 | where 116 | I: tokio::io::AsyncRead, 117 | { 118 | #[inline] 119 | fn poll_read( 120 | self: Pin<&mut Self>, 121 | cx: &mut Context<'_>, 122 | buf: &mut tokio::io::ReadBuf<'_>, 123 | ) -> Poll> { 124 | self.project().inner.poll_read(cx, buf) 125 | } 126 | } 127 | 128 | /// [`WithHyperIo`] exposes its inner `I`'s [`AsyncWrite`] implementation. 
129 | /// 130 | /// [`AsyncWrite`]: tokio::io::AsyncWrite 131 | impl tokio::io::AsyncWrite for WithHyperIo 132 | where 133 | I: tokio::io::AsyncWrite, 134 | { 135 | #[inline] 136 | fn poll_write( 137 | self: Pin<&mut Self>, 138 | cx: &mut Context<'_>, 139 | buf: &[u8], 140 | ) -> Poll> { 141 | self.project().inner.poll_write(cx, buf) 142 | } 143 | 144 | #[inline] 145 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 146 | self.project().inner.poll_flush(cx) 147 | } 148 | 149 | #[inline] 150 | fn poll_shutdown( 151 | self: Pin<&mut Self>, 152 | cx: &mut Context<'_>, 153 | ) -> Poll> { 154 | self.project().inner.poll_shutdown(cx) 155 | } 156 | 157 | #[inline] 158 | fn is_write_vectored(&self) -> bool { 159 | self.inner.is_write_vectored() 160 | } 161 | 162 | #[inline] 163 | fn poll_write_vectored( 164 | self: Pin<&mut Self>, 165 | cx: &mut Context<'_>, 166 | bufs: &[std::io::IoSlice<'_>], 167 | ) -> Poll> { 168 | self.project().inner.poll_write_vectored(cx, bufs) 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /src/rt/tokio/with_tokio_io.rs: -------------------------------------------------------------------------------- 1 | use pin_project_lite::pin_project; 2 | use std::{ 3 | pin::Pin, 4 | task::{Context, Poll}, 5 | }; 6 | 7 | pin_project! { 8 | /// Extends an underlying [`hyper`] I/O with [`tokio`] I/O implementations. 9 | /// 10 | /// This implements [`AsyncRead`] and [`AsyncWrite`] given an inner type that implements 11 | /// [`Read`] and [`Write`], respectively. 12 | #[derive(Debug)] 13 | pub struct WithTokioIo { 14 | #[pin] 15 | inner: I, 16 | } 17 | } 18 | 19 | // ==== impl WithTokioIo ===== 20 | 21 | /// [`WithTokioIo`] is [`AsyncRead`] if `I` is [`Read`]. 
22 | /// 23 | /// [`AsyncRead`]: tokio::io::AsyncRead 24 | /// [`Read`]: hyper::rt::Read 25 | impl tokio::io::AsyncRead for WithTokioIo 26 | where 27 | I: hyper::rt::Read, 28 | { 29 | fn poll_read( 30 | self: Pin<&mut Self>, 31 | cx: &mut Context<'_>, 32 | tbuf: &mut tokio::io::ReadBuf<'_>, 33 | ) -> Poll> { 34 | //let init = tbuf.initialized().len(); 35 | let filled = tbuf.filled().len(); 36 | let sub_filled = unsafe { 37 | let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut()); 38 | 39 | match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) { 40 | Poll::Ready(Ok(())) => buf.filled().len(), 41 | other => return other, 42 | } 43 | }; 44 | 45 | let n_filled = filled + sub_filled; 46 | // At least sub_filled bytes had to have been initialized. 47 | let n_init = sub_filled; 48 | unsafe { 49 | tbuf.assume_init(n_init); 50 | tbuf.set_filled(n_filled); 51 | } 52 | 53 | Poll::Ready(Ok(())) 54 | } 55 | } 56 | 57 | /// [`WithTokioIo`] is [`AsyncWrite`] if `I` is [`Write`]. 
58 | /// 59 | /// [`AsyncWrite`]: tokio::io::AsyncWrite 60 | /// [`Write`]: hyper::rt::Write 61 | impl tokio::io::AsyncWrite for WithTokioIo 62 | where 63 | I: hyper::rt::Write, 64 | { 65 | fn poll_write( 66 | self: Pin<&mut Self>, 67 | cx: &mut Context<'_>, 68 | buf: &[u8], 69 | ) -> Poll> { 70 | hyper::rt::Write::poll_write(self.project().inner, cx, buf) 71 | } 72 | 73 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 74 | hyper::rt::Write::poll_flush(self.project().inner, cx) 75 | } 76 | 77 | fn poll_shutdown( 78 | self: Pin<&mut Self>, 79 | cx: &mut Context<'_>, 80 | ) -> Poll> { 81 | hyper::rt::Write::poll_shutdown(self.project().inner, cx) 82 | } 83 | 84 | fn is_write_vectored(&self) -> bool { 85 | hyper::rt::Write::is_write_vectored(&self.inner) 86 | } 87 | 88 | fn poll_write_vectored( 89 | self: Pin<&mut Self>, 90 | cx: &mut Context<'_>, 91 | bufs: &[std::io::IoSlice<'_>], 92 | ) -> Poll> { 93 | hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs) 94 | } 95 | } 96 | 97 | /// [`WithTokioIo`] exposes its inner `I`'s [`Write`] implementation. 
98 | /// 99 | /// [`Write`]: hyper::rt::Write 100 | impl hyper::rt::Write for WithTokioIo 101 | where 102 | I: hyper::rt::Write, 103 | { 104 | #[inline] 105 | fn poll_write( 106 | self: Pin<&mut Self>, 107 | cx: &mut Context<'_>, 108 | buf: &[u8], 109 | ) -> Poll> { 110 | self.project().inner.poll_write(cx, buf) 111 | } 112 | 113 | #[inline] 114 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 115 | self.project().inner.poll_flush(cx) 116 | } 117 | 118 | #[inline] 119 | fn poll_shutdown( 120 | self: Pin<&mut Self>, 121 | cx: &mut Context<'_>, 122 | ) -> Poll> { 123 | self.project().inner.poll_shutdown(cx) 124 | } 125 | 126 | #[inline] 127 | fn is_write_vectored(&self) -> bool { 128 | self.inner.is_write_vectored() 129 | } 130 | 131 | #[inline] 132 | fn poll_write_vectored( 133 | self: Pin<&mut Self>, 134 | cx: &mut Context<'_>, 135 | bufs: &[std::io::IoSlice<'_>], 136 | ) -> Poll> { 137 | self.project().inner.poll_write_vectored(cx, bufs) 138 | } 139 | } 140 | 141 | impl WithTokioIo { 142 | /// Wraps the inner I/O in an [`WithTokioIo`] 143 | pub fn new(inner: I) -> Self { 144 | Self { inner } 145 | } 146 | 147 | /// Returns a reference to the inner type. 148 | pub fn inner(&self) -> &I { 149 | &self.inner 150 | } 151 | 152 | /// Returns a mutable reference to the inner type. 153 | pub fn inner_mut(&mut self) -> &mut I { 154 | &mut self.inner 155 | } 156 | 157 | /// Consumes this wrapper and returns the inner type. 158 | pub fn into_inner(self) -> I { 159 | self.inner 160 | } 161 | } 162 | 163 | /// [`WithTokioIo`] exposes its inner `I`'s [`Read`] implementation. 
164 | /// 165 | /// [`Read`]: hyper::rt::Read 166 | impl hyper::rt::Read for WithTokioIo 167 | where 168 | I: hyper::rt::Read, 169 | { 170 | #[inline] 171 | fn poll_read( 172 | self: Pin<&mut Self>, 173 | cx: &mut Context<'_>, 174 | buf: hyper::rt::ReadBufCursor<'_>, 175 | ) -> Poll> { 176 | self.project().inner.poll_read(cx, buf) 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /src/server/conn/auto/upgrade.rs: -------------------------------------------------------------------------------- 1 | //! Upgrade utilities. 2 | 3 | use bytes::{Bytes, BytesMut}; 4 | use hyper::{ 5 | rt::{Read, Write}, 6 | upgrade::Upgraded, 7 | }; 8 | 9 | use crate::common::rewind::Rewind; 10 | 11 | /// Tries to downcast the internal trait object to the type passed. 12 | /// 13 | /// On success, returns the downcasted parts. On error, returns the Upgraded back. 14 | /// This is a kludge to work around the fact that the machinery provided by 15 | /// [`hyper_util::server::conn::auto`] wraps the inner `T` with a private type 16 | /// that is not reachable from outside the crate. 17 | /// 18 | /// [`hyper_util::server::conn::auto`]: crate::server::conn::auto 19 | /// 20 | /// This kludge will be removed when this machinery is added back to the main 21 | /// `hyper` code. 22 | pub fn downcast(upgraded: Upgraded) -> Result, Upgraded> 23 | where 24 | T: Read + Write + Unpin + 'static, 25 | { 26 | let hyper::upgrade::Parts { 27 | io: rewind, 28 | mut read_buf, 29 | .. 30 | } = upgraded.downcast::>()?; 31 | 32 | if let Some(pre) = rewind.pre { 33 | read_buf = if read_buf.is_empty() { 34 | pre 35 | } else { 36 | let mut buf = BytesMut::from(read_buf); 37 | 38 | buf.extend_from_slice(&pre); 39 | 40 | buf.freeze() 41 | }; 42 | } 43 | 44 | Ok(Parts { 45 | io: rewind.inner, 46 | read_buf, 47 | }) 48 | } 49 | 50 | /// The deconstructed parts of an [`Upgraded`] type. 
51 | /// 52 | /// Includes the original IO type, and a read buffer of bytes that the 53 | /// HTTP state machine may have already read before completing an upgrade. 54 | #[derive(Debug)] 55 | #[non_exhaustive] 56 | pub struct Parts { 57 | /// The original IO object used before the upgrade. 58 | pub io: T, 59 | /// A buffer of bytes that have been read but not processed as HTTP. 60 | /// 61 | /// For instance, if the `Connection` is used for an HTTP upgrade request, 62 | /// it is possible the server sent back the first bytes of the new protocol 63 | /// along with the response upgrade. 64 | /// 65 | /// You will want to check for any existing bytes if you plan to continue 66 | /// communicating on the IO object. 67 | pub read_buf: Bytes, 68 | } 69 | -------------------------------------------------------------------------------- /src/server/conn/mod.rs: -------------------------------------------------------------------------------- 1 | //! Connection utilities. 2 | 3 | #[cfg(any(feature = "http1", feature = "http2"))] 4 | pub mod auto; 5 | -------------------------------------------------------------------------------- /src/server/graceful.rs: -------------------------------------------------------------------------------- 1 | //! Utility to gracefully shutdown a server. 2 | //! 3 | //! This module provides a [`GracefulShutdown`] type, 4 | //! which can be used to gracefully shutdown a server. 5 | //! 6 | //! See 7 | //! for an example of how to use this. 8 | 9 | use std::{ 10 | fmt::{self, Debug}, 11 | future::Future, 12 | pin::Pin, 13 | task::{self, Poll}, 14 | }; 15 | 16 | use pin_project_lite::pin_project; 17 | use tokio::sync::watch; 18 | 19 | /// A graceful shutdown utility 20 | // Purposefully not `Clone`, see `watcher()` method for why. 21 | pub struct GracefulShutdown { 22 | tx: watch::Sender<()>, 23 | } 24 | 25 | /// A watcher side of the graceful shutdown. 26 | /// 27 | /// This type can only watch a connection, it cannot trigger a shutdown. 
28 | /// 29 | /// Call [`GracefulShutdown::watcher()`] to construct one of these. 30 | pub struct Watcher { 31 | rx: watch::Receiver<()>, 32 | } 33 | 34 | impl GracefulShutdown { 35 | /// Create a new graceful shutdown helper. 36 | pub fn new() -> Self { 37 | let (tx, _) = watch::channel(()); // the initial receiver is dropped immediately; receivers are handed out on demand by `watcher()` 38 | Self { tx } 39 | } 40 | 41 | /// Wrap a future for graceful shutdown watching; shorthand for `self.watcher().watch(conn)`. 42 | pub fn watch(&self, conn: C) -> impl Future { 43 | self.watcher().watch(conn) 44 | } 45 | 46 | /// Create an owned type that can watch a connection. 47 | /// 48 | /// This method allows creating an owned type that can be sent onto another 49 | /// task before calling [`Watcher::watch()`]. 50 | // Internal: this function exists because `Clone` allows footguns. 51 | // If the `tx` were cloned (or the `rx`), race conditions can happen where 52 | // one task starting a shutdown is scheduled and intertwined with a task 53 | // starting to watch a connection, and the "watch version" is one behind. 54 | pub fn watcher(&self) -> Watcher { 55 | let rx = self.tx.subscribe(); 56 | Watcher { rx } 57 | } 58 | 59 | /// Signal shutdown for all watched connections. 60 | /// 61 | /// This returns a `Future` which will complete once all watched 62 | /// connections have shut down, i.e. once every receiver created by [`GracefulShutdown::watcher()`] (including any [`Watcher`] never passed to [`Watcher::watch()`]) has been dropped. 63 | pub async fn shutdown(self) { 64 | let Self { tx } = self; 65 | 66 | // signal all the watched futures about the change 67 | let _ = tx.send(()); 68 | // and then wait for all of them to complete, i.e. for every receiver to be dropped 69 | tx.closed().await; 70 | } 71 | 72 | /// Returns the number of outstanding watchers: receivers created by [`GracefulShutdown::watcher()`] that are still alive, including [`Watcher`]s not yet passed to [`Watcher::watch()`]. 73 | pub fn count(&self) -> usize { 74 | self.tx.receiver_count() 75 | } 76 | } 77 | 78 | impl Debug for GracefulShutdown { 79 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 80 | f.debug_struct("GracefulShutdown").finish() 81 | } 82 | } 83 | 84 | impl Default for GracefulShutdown { 85 | fn default() -> Self { 86 | Self::new() 87 | } 88 | } 89 | 90 | impl Watcher { 91 | /// Wrap a future for graceful shutdown watching, consuming this watcher.
92 | pub fn watch(self, conn: C) -> impl Future { 93 | let Watcher { mut rx } = self; 94 | GracefulConnectionFuture::new(conn, async move { 95 | let _ = rx.changed().await; 96 | // `changed()` resolves once `shutdown()` sends. Its result is ignored, so if it returns an error instead (presumably when the `GracefulShutdown` is dropped without calling `shutdown()` — confirm against tokio's `watch::Receiver::changed` docs) graceful shutdown is triggered all the same. Hold onto the rx until the watched future is completed 97 | rx 98 | }) 99 | } 100 | } 101 | 102 | impl Debug for Watcher { 103 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 104 | f.debug_struct("GracefulWatcher").finish() 105 | } 106 | } 107 | 108 | pin_project! { 109 | struct GracefulConnectionFuture { 110 | #[pin] 111 | conn: C, 112 | #[pin] 113 | cancel: F, 114 | #[pin] 115 | // If cancelled, this holds the output of `cancel` (for futures built by `Watcher::watch`, the watcher's rx) until the inner conn is done. 116 | cancelled_guard: Option, 117 | } 118 | } 119 | 120 | impl GracefulConnectionFuture { 121 | fn new(conn: C, cancel: F) -> Self { 122 | Self { 123 | conn, 124 | cancel, 125 | cancelled_guard: None, 126 | } 127 | } 128 | } 129 | 130 | impl Debug for GracefulConnectionFuture { 131 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 132 | f.debug_struct("GracefulConnectionFuture").finish() 133 | } 134 | } 135 | 136 | impl Future for GracefulConnectionFuture 137 | where 138 | C: GracefulConnection, 139 | F: Future, 140 | { 141 | type Output = C::Output; 142 | 143 | fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { 144 | let mut this = self.project(); 145 | if this.cancelled_guard.is_none() { // `cancel` is only polled until it first resolves, never again afterwards 146 | if let Poll::Ready(guard) = this.cancel.poll(cx) { 147 | this.cancelled_guard.set(Some(guard)); 148 | this.conn.as_mut().graceful_shutdown(); // ask the connection to wind down; it is still driven to completion below 149 | } 150 | } 151 | this.conn.poll(cx) // the output is always the connection's own, whether or not shutdown was signalled 152 | } 153 | } 154 | 155 | /// An internal utility trait as an umbrella target for all (hyper) connection 156 | /// types that the [`GracefulShutdown`] can watch. It is sealed (it requires the private `Sealed` supertrait), so it cannot be implemented outside this module. 157 | pub trait GracefulConnection: Future> + private::Sealed { // NOTE(review): `private` also seals hyper's `http1::UpgradeableConnection`, but this file has no `GracefulConnection` impl for it — confirm whether one is intended 158 | /// The error type returned by the connection when used as a future. 159 | type Error; 160 | 161 | /// Start a graceful shutdown process for this connection.
162 | fn graceful_shutdown(self: Pin<&mut Self>); 163 | } 164 | 165 | #[cfg(feature = "http1")] 166 | impl GracefulConnection for hyper::server::conn::http1::Connection 167 | where 168 | S: hyper::service::HttpService, 169 | S::Error: Into>, 170 | I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, 171 | B: hyper::body::Body + 'static, 172 | B::Error: Into>, 173 | { 174 | type Error = hyper::Error; 175 | 176 | fn graceful_shutdown(self: Pin<&mut Self>) { 177 | hyper::server::conn::http1::Connection::graceful_shutdown(self); 178 | } 179 | } 180 | 181 | #[cfg(feature = "http2")] 182 | impl GracefulConnection for hyper::server::conn::http2::Connection 183 | where 184 | S: hyper::service::HttpService, 185 | S::Error: Into>, 186 | I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, 187 | B: hyper::body::Body + 'static, 188 | B::Error: Into>, 189 | E: hyper::rt::bounds::Http2ServerConnExec, 190 | { 191 | type Error = hyper::Error; 192 | 193 | fn graceful_shutdown(self: Pin<&mut Self>) { 194 | hyper::server::conn::http2::Connection::graceful_shutdown(self); 195 | } 196 | } 197 | 198 | #[cfg(feature = "server-auto")] 199 | impl GracefulConnection for crate::server::conn::auto::Connection<'_, I, S, E> 200 | where 201 | S: hyper::service::Service, Response = http::Response>, 202 | S::Error: Into>, 203 | S::Future: 'static, 204 | I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, 205 | B: hyper::body::Body + 'static, 206 | B::Error: Into>, 207 | E: hyper::rt::bounds::Http2ServerConnExec, 208 | { 209 | type Error = Box; 210 | 211 | fn graceful_shutdown(self: Pin<&mut Self>) { 212 | crate::server::conn::auto::Connection::graceful_shutdown(self); 213 | } 214 | } 215 | 216 | #[cfg(feature = "server-auto")] 217 | impl GracefulConnection 218 | for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E> 219 | where 220 | S: hyper::service::Service, Response = http::Response>, 221 | S::Error: Into>, 222 | S::Future: 'static, 223 | I: hyper::rt::Read +
hyper::rt::Write + Unpin + Send + 'static, 224 | B: hyper::body::Body + 'static, 225 | B::Error: Into>, 226 | E: hyper::rt::bounds::Http2ServerConnExec, 227 | { 228 | type Error = Box; 229 | 230 | fn graceful_shutdown(self: Pin<&mut Self>) { 231 | crate::server::conn::auto::UpgradeableConnection::graceful_shutdown(self); 232 | } 233 | } 234 | 235 | mod private { 236 | pub trait Sealed {} 237 | 238 | #[cfg(feature = "http1")] 239 | impl Sealed for hyper::server::conn::http1::Connection 240 | where 241 | S: hyper::service::HttpService, 242 | S::Error: Into>, 243 | I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, 244 | B: hyper::body::Body + 'static, 245 | B::Error: Into>, 246 | { 247 | } 248 | 249 | #[cfg(feature = "http1")] 250 | impl Sealed for hyper::server::conn::http1::UpgradeableConnection 251 | where 252 | S: hyper::service::HttpService, 253 | S::Error: Into>, 254 | I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, 255 | B: hyper::body::Body + 'static, 256 | B::Error: Into>, 257 | { 258 | } 259 | 260 | #[cfg(feature = "http2")] 261 | impl Sealed for hyper::server::conn::http2::Connection 262 | where 263 | S: hyper::service::HttpService, 264 | S::Error: Into>, 265 | I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, 266 | B: hyper::body::Body + 'static, 267 | B::Error: Into>, 268 | E: hyper::rt::bounds::Http2ServerConnExec, 269 | { 270 | } 271 | 272 | #[cfg(feature = "server-auto")] 273 | impl Sealed for crate::server::conn::auto::Connection<'_, I, S, E> 274 | where 275 | S: hyper::service::Service< 276 | http::Request, 277 | Response = http::Response, 278 | >, 279 | S::Error: Into>, 280 | S::Future: 'static, 281 | I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static, 282 | B: hyper::body::Body + 'static, 283 | B::Error: Into>, 284 | E: hyper::rt::bounds::Http2ServerConnExec, 285 | { 286 | } 287 | 288 | #[cfg(feature = "server-auto")] 289 | impl Sealed for crate::server::conn::auto::UpgradeableConnection<'_, I, S, E> 290 | where 291 | S:
hyper::service::Service< 292 | http::Request, 293 | Response = http::Response, 294 | >, 295 | S::Error: Into>, 296 | S::Future: 'static, 297 | I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static, 298 | B: hyper::body::Body + 'static, 299 | B::Error: Into>, 300 | E: hyper::rt::bounds::Http2ServerConnExec, 301 | { 302 | } 303 | } 304 | 305 | #[cfg(test)] 306 | mod test { 307 | use super::*; 308 | use pin_project_lite::pin_project; 309 | use std::sync::atomic::{AtomicUsize, Ordering}; 310 | use std::sync::Arc; 311 | 312 | pin_project! { 313 | #[derive(Debug)] 314 | struct DummyConnection { 315 | #[pin] 316 | future: F, 317 | shutdown_counter: Arc, 318 | } 319 | } 320 | 321 | impl private::Sealed for DummyConnection {} 322 | 323 | impl GracefulConnection for DummyConnection { 324 | type Error = (); 325 | 326 | fn graceful_shutdown(self: Pin<&mut Self>) { 327 | self.shutdown_counter.fetch_add(1, Ordering::SeqCst); 328 | } 329 | } 330 | 331 | impl Future for DummyConnection { 332 | type Output = Result<(), ()>; 333 | 334 | fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { 335 | match self.project().future.poll(cx) { 336 | Poll::Ready(_) => Poll::Ready(Ok(())), 337 | Poll::Pending => Poll::Pending, 338 | } 339 | } 340 | } 341 | 342 | #[cfg(not(miri))] 343 | #[tokio::test] 344 | async fn test_graceful_shutdown_ok() { 345 | let graceful = GracefulShutdown::new(); 346 | let shutdown_counter = Arc::new(AtomicUsize::new(0)); 347 | let (dummy_tx, _) = tokio::sync::broadcast::channel(1); 348 | 349 | for i in 1..=3 { 350 | let mut dummy_rx = dummy_tx.subscribe(); 351 | let shutdown_counter = shutdown_counter.clone(); 352 | 353 | let future = async move { 354 | tokio::time::sleep(std::time::Duration::from_millis(i * 10)).await; 355 | let _ = dummy_rx.recv().await; 356 | }; 357 | let dummy_conn = DummyConnection { 358 | future, 359 | shutdown_counter, 360 | }; 361 | let conn = graceful.watch(dummy_conn); 362 | tokio::spawn(async move { 363 |
conn.await.unwrap(); 364 | }); 365 | } 366 | 367 | assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); 368 | let _ = dummy_tx.send(()); 369 | 370 | tokio::select! { 371 | _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => { 372 | panic!("timeout") 373 | }, 374 | _ = graceful.shutdown() => { 375 | assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); 376 | } 377 | } 378 | } 379 | 380 | #[cfg(not(miri))] 381 | #[tokio::test] 382 | async fn test_graceful_shutdown_delayed_ok() { 383 | let graceful = GracefulShutdown::new(); 384 | let shutdown_counter = Arc::new(AtomicUsize::new(0)); 385 | 386 | for i in 1..=3 { 387 | let shutdown_counter = shutdown_counter.clone(); 388 | 389 | // each connection completes on its own after i * 50 ms, within the 200 ms timeout below; `shutdown()` must wait for all of them 390 | let future = async move { 391 | tokio::time::sleep(std::time::Duration::from_millis(i * 50)).await; 392 | }; 393 | let dummy_conn = DummyConnection { 394 | future, 395 | shutdown_counter, 396 | }; 397 | let conn = graceful.watch(dummy_conn); 398 | tokio::spawn(async move { 399 | conn.await.unwrap(); 400 | }); 401 | } 402 | 403 | assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); 404 | 405 | tokio::select!
{ 406 | _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => { 407 | panic!("timeout") 408 | }, 409 | _ = graceful.shutdown() => { 410 | assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); 411 | } 412 | } 413 | } 414 | 415 | #[cfg(not(miri))] 416 | #[tokio::test] 417 | async fn test_graceful_shutdown_multi_per_watcher_ok() { 418 | let graceful = GracefulShutdown::new(); 419 | let shutdown_counter = Arc::new(AtomicUsize::new(0)); 420 | 421 | for i in 1..=3 { 422 | let shutdown_counter = shutdown_counter.clone(); 423 | 424 | let mut futures = Vec::new(); 425 | for u in 1..=i { 426 | let future = tokio::time::sleep(std::time::Duration::from_millis(u * 50)); 427 | let dummy_conn = DummyConnection { 428 | future, 429 | shutdown_counter: shutdown_counter.clone(), 430 | }; 431 | let conn = graceful.watch(dummy_conn); 432 | futures.push(conn); 433 | } 434 | tokio::spawn(async move { 435 | futures_util::future::join_all(futures).await; 436 | }); 437 | } 438 | 439 | assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); 440 | 441 | tokio::select!
{ 442 | _ = tokio::time::sleep(std::time::Duration::from_millis(200)) => { 443 | panic!("timeout") 444 | }, 445 | _ = graceful.shutdown() => { 446 | assert_eq!(shutdown_counter.load(Ordering::SeqCst), 6); 447 | } 448 | } 449 | } 450 | 451 | #[cfg(not(miri))] 452 | #[tokio::test] 453 | async fn test_graceful_shutdown_timeout() { 454 | let graceful = GracefulShutdown::new(); 455 | let shutdown_counter = Arc::new(AtomicUsize::new(0)); 456 | 457 | for i in 1..=3 { 458 | let shutdown_counter = shutdown_counter.clone(); 459 | 460 | let future = async move { 461 | if i == 1 { 462 | std::future::pending::<()>().await 463 | } else { 464 | std::future::ready(()).await 465 | } 466 | }; 467 | let dummy_conn = DummyConnection { 468 | future, 469 | shutdown_counter, 470 | }; 471 | let conn = graceful.watch(dummy_conn); 472 | tokio::spawn(async move { 473 | conn.await.unwrap(); 474 | }); 475 | } 476 | 477 | assert_eq!(shutdown_counter.load(Ordering::SeqCst), 0); 478 | 479 | tokio::select! { 480 | _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => { 481 | assert_eq!(shutdown_counter.load(Ordering::SeqCst), 3); 482 | }, 483 | _ = graceful.shutdown() => { 484 | panic!("shutdown should not be completed: as not all our conns finish") 485 | } 486 | } 487 | } 488 | } 489 | -------------------------------------------------------------------------------- /src/server/mod.rs: -------------------------------------------------------------------------------- 1 | //! Server utilities.
2 | 3 | pub mod conn; 4 | 5 | #[cfg(feature = "server-graceful")] 6 | pub mod graceful; 7 | -------------------------------------------------------------------------------- /src/service/glue.rs: -------------------------------------------------------------------------------- 1 | use pin_project_lite::pin_project; 2 | use std::{ 3 | future::Future, 4 | pin::Pin, 5 | task::{Context, Poll}, 6 | }; 7 | 8 | use super::Oneshot; 9 | 10 | /// A tower [`Service`][tower-svc] converted into a hyper [`Service`][hyper-svc]. 11 | /// 12 | /// This wraps an inner tower service `S` in a [`hyper::service::Service`] implementation. See 13 | /// the module-level documentation of [`service`][crate::service] for more information about using 14 | /// [`tower`][tower] services and middleware with [`hyper`]. 15 | /// 16 | /// [hyper-svc]: hyper::service::Service 17 | /// [tower]: https://docs.rs/tower/latest/tower/ 18 | /// [tower-svc]: https://docs.rs/tower/latest/tower/trait.Service.html 19 | #[derive(Debug, Copy, Clone)] 20 | pub struct TowerToHyperService { 21 | service: S, 22 | } 23 | 24 | impl TowerToHyperService { 25 | /// Create a new [`TowerToHyperService`] from a tower service. 26 | pub fn new(tower_service: S) -> Self { 27 | Self { 28 | service: tower_service, 29 | } 30 | } 31 | } 32 | 33 | impl hyper::service::Service for TowerToHyperService 34 | where 35 | S: tower_service::Service + Clone, // `Clone` because `call` only gets `&self`, while `Oneshot` needs an owned service 36 | { 37 | type Response = S::Response; 38 | type Error = S::Error; 39 | type Future = TowerToHyperServiceFuture; 40 | 41 | fn call(&self, req: R) -> Self::Future { 42 | TowerToHyperServiceFuture { 43 | future: Oneshot::new(self.service.clone(), req), // each request drives its own clone: `Oneshot` polls `poll_ready` on it and then `call`s it 44 | } 45 | } 46 | } 47 | 48 | pin_project! { 49 | /// Response future for [`TowerToHyperService`]. 50 | /// 51 | /// This future is acquired by [`call`][hyper::service::Service::call]ing a 52 | /// [`TowerToHyperService`]. It resolves to the wrapped tower service's response or error.
53 | pub struct TowerToHyperServiceFuture 54 | where 55 | S: tower_service::Service, 56 | { 57 | #[pin] 58 | future: Oneshot, 59 | } 60 | } 61 | 62 | impl Future for TowerToHyperServiceFuture 63 | where 64 | S: tower_service::Service, 65 | { 66 | type Output = Result; 67 | 68 | #[inline] 69 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 70 | self.project().future.poll(cx) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /src/service/mod.rs: -------------------------------------------------------------------------------- 1 | //! Service utilities. 2 | //! 3 | //! [`hyper::service`] provides a [`Service`][hyper-svc] trait, representing an asynchronous 4 | //! function from a `Request` to a `Response`. This provides an interface allowing middleware for 5 | //! network applications to be written in a modular and reusable way. 6 | //! 7 | //! This submodule provides an assortment of utilities for working with [`Service`][hyper-svc]s. 8 | //! See the module-level documentation of [`hyper::service`] for more information. 9 | //! 10 | //! # Tower 11 | //! 12 | //! While [`hyper`] uses its own notion of a [`Service`][hyper-svc] internally, many other 13 | //! libraries use a library such as [`tower`][tower] to provide the fundamental model of an 14 | //! asynchronous function. 15 | //! 16 | //! The [`TowerToHyperService`] type provided by this submodule can be used to bridge these 17 | //! ecosystems together. By wrapping a [`tower::Service`][tower-svc] in [`TowerToHyperService`], 18 | //! it can be passed into [`hyper`] interfaces that expect a [`hyper::service::Service`]. 19 | //! 20 | //! [hyper-svc]: hyper::service::Service 21 | //! [tower]: https://docs.rs/tower/latest/tower/ 22 | //!
[tower-svc]: https://docs.rs/tower/latest/tower/trait.Service.html 23 | 24 | #[cfg(feature = "service")] 25 | mod glue; 26 | #[cfg(any(feature = "client-legacy", feature = "service"))] 27 | mod oneshot; 28 | 29 | #[cfg(feature = "service")] 30 | pub use self::glue::{TowerToHyperService, TowerToHyperServiceFuture}; 31 | #[cfg(any(feature = "client-legacy", feature = "service"))] 32 | pub(crate) use self::oneshot::Oneshot; 33 | -------------------------------------------------------------------------------- /src/service/oneshot.rs: -------------------------------------------------------------------------------- 1 | use futures_core::ready; 2 | use pin_project_lite::pin_project; 3 | use std::future::Future; 4 | use std::pin::Pin; 5 | use std::task::{Context, Poll}; 6 | use tower_service::Service; 7 | 8 | // Vendored from tower::util to reduce dependencies, the code is small enough. 9 | 10 | // Not really pub, but used in a trait for bounds 11 | pin_project! { 12 | #[project = OneshotProj] 13 | #[derive(Debug)] 14 | pub enum Oneshot, Req> { 15 | NotReady { 16 | svc: S, 17 | req: Option, 18 | }, 19 | Called { 20 | #[pin] 21 | fut: S::Future, 22 | }, 23 | Done, 24 | } 25 | } 26 | 27 | impl Oneshot 28 | where 29 | S: Service, 30 | { 31 | pub(crate) const fn new(svc: S, req: Req) -> Self { 32 | Oneshot::NotReady { 33 | svc, 34 | req: Some(req), 35 | } 36 | } 37 | } 38 | 39 | impl Future for Oneshot 40 | where 41 | S: Service, 42 | { 43 | type Output = Result; 44 | 45 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 46 | loop { 47 | let this = self.as_mut().project(); 48 | match this { 49 | OneshotProj::NotReady { svc, req } => { 50 | ready!(svc.poll_ready(cx))?; 51 | let fut = svc.call(req.take().expect("already called")); 52 | self.set(Oneshot::Called { fut }); 53 | } 54 | OneshotProj::Called { fut } => { 55 | let res = ready!(fut.poll(cx))?; 56 | self.set(Oneshot::Done); 57 | return Poll::Ready(Ok(res)); 58 | } 59 | OneshotProj::Done => panic!("polled 
after complete"), 60 | } 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /tests/proxy.rs: -------------------------------------------------------------------------------- 1 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 2 | use tokio::net::{TcpListener, TcpStream}; 3 | use tower_service::Service; 4 | 5 | use hyper_util::client::legacy::connect::proxy::{SocksV4, SocksV5, Tunnel}; 6 | use hyper_util::client::legacy::connect::HttpConnector; 7 | 8 | #[cfg(not(miri))] 9 | #[tokio::test] 10 | async fn test_tunnel_works() { 11 | let tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind"); 12 | let addr = tcp.local_addr().expect("local_addr"); 13 | 14 | let proxy_dst = format!("http://{addr}").parse().expect("uri"); 15 | let mut connector = Tunnel::new(proxy_dst, HttpConnector::new()); 16 | let t1 = tokio::spawn(async move { 17 | let _conn = connector 18 | .call("https://hyper.rs".parse().unwrap()) 19 | .await 20 | .expect("tunnel"); 21 | }); 22 | 23 | let t2 = tokio::spawn(async move { 24 | let (mut io, _) = tcp.accept().await.expect("accept"); 25 | let mut buf = [0u8; 64]; 26 | let n = io.read(&mut buf).await.expect("read 1"); 27 | assert_eq!( 28 | &buf[..n], 29 | b"CONNECT hyper.rs:443 HTTP/1.1\r\nHost: hyper.rs:443\r\n\r\n" 30 | ); 31 | io.write_all(b"HTTP/1.1 200 OK\r\n\r\n") 32 | .await 33 | .expect("write 1"); 34 | }); 35 | 36 | t1.await.expect("task 1"); 37 | t2.await.expect("task 2"); 38 | } 39 | 40 | #[cfg(not(miri))] 41 | #[tokio::test] 42 | async fn test_socks_v5_without_auth_works() { 43 | let proxy_tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind"); 44 | let proxy_addr = proxy_tcp.local_addr().expect("local_addr"); 45 | let proxy_dst = format!("http://{proxy_addr}").parse().expect("uri"); 46 | 47 | let target_tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind"); 48 | let target_addr = target_tcp.local_addr().expect("local_addr"); 49 | let target_dst = 
format!("http://{target_addr}").parse().expect("uri"); 50 | 51 | let mut connector = SocksV5::new(proxy_dst, HttpConnector::new()); 52 | 53 | // Client 54 | // 55 | // Will use `SocksV5` to establish proxy tunnel. 56 | // Will send "Hello World!" to the target and receive "Goodbye!" back. 57 | let t1 = tokio::spawn(async move { 58 | let conn = connector.call(target_dst).await.expect("tunnel"); 59 | let mut tcp = conn.into_inner(); 60 | 61 | tcp.write_all(b"Hello World!").await.expect("write 1"); 62 | 63 | let mut buf = [0u8; 64]; 64 | let n = tcp.read(&mut buf).await.expect("read 1"); 65 | assert_eq!(&buf[..n], b"Goodbye!"); 66 | }); 67 | 68 | // Proxy 69 | // 70 | // Will receive CONNECT command from client. 71 | // Will connect to target and success code back to client. 72 | // Will blindly tunnel between client and target. 73 | let t2 = tokio::spawn(async move { 74 | let (mut to_client, _) = proxy_tcp.accept().await.expect("accept"); 75 | let mut buf = [0u8; 513]; 76 | 77 | // negotiation req/res 78 | let n = to_client.read(&mut buf).await.expect("read 1"); 79 | assert_eq!(&buf[..n], [0x05, 0x01, 0x00]); 80 | 81 | to_client.write_all(&[0x05, 0x00]).await.expect("write 1"); 82 | 83 | // command req/rs 84 | let [p1, p2] = target_addr.port().to_be_bytes(); 85 | let [ip1, ip2, ip3, ip4] = [0x7f, 0x00, 0x00, 0x01]; 86 | let message = [0x05, 0x01, 0x00, 0x01, ip1, ip2, ip3, ip4, p1, p2]; 87 | let n = to_client.read(&mut buf).await.expect("read 2"); 88 | assert_eq!(&buf[..n], message); 89 | 90 | let mut to_target = TcpStream::connect(target_addr).await.expect("connect"); 91 | 92 | let message = [0x05, 0x00, 0x00, 0x01, ip1, ip2, ip3, ip4, p1, p2]; 93 | to_client.write_all(&message).await.expect("write 2"); 94 | 95 | let (from_client, from_target) = 96 | tokio::io::copy_bidirectional(&mut to_client, &mut to_target) 97 | .await 98 | .expect("proxy"); 99 | 100 | assert_eq!(from_client, 12); 101 | assert_eq!(from_target, 8) 102 | }); 103 | 104 | // Target server 105 | // 
106 | // Will accept connection from proxy server 107 | // Will receive "Hello World!" from the client and return "Goodbye!" 108 | let t3 = tokio::spawn(async move { 109 | let (mut io, _) = target_tcp.accept().await.expect("accept"); 110 | let mut buf = [0u8; 64]; 111 | 112 | let n = io.read(&mut buf).await.expect("read 1"); 113 | assert_eq!(&buf[..n], b"Hello World!"); 114 | 115 | io.write_all(b"Goodbye!").await.expect("write 1"); 116 | }); 117 | 118 | t1.await.expect("task - client"); 119 | t2.await.expect("task - proxy"); 120 | t3.await.expect("task - target"); 121 | } 122 | 123 | #[cfg(not(miri))] 124 | #[tokio::test] 125 | async fn test_socks_v5_with_auth_works() { 126 | let proxy_tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind"); 127 | let proxy_addr = proxy_tcp.local_addr().expect("local_addr"); 128 | let proxy_dst = format!("http://{proxy_addr}").parse().expect("uri"); 129 | 130 | let target_tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind"); 131 | let target_addr = target_tcp.local_addr().expect("local_addr"); 132 | let target_dst = format!("http://{target_addr}").parse().expect("uri"); 133 | 134 | let mut connector = 135 | SocksV5::new(proxy_dst, HttpConnector::new()).with_auth("user".into(), "pass".into()); 136 | 137 | // Client 138 | // 139 | // Will use `SocksV5` to establish proxy tunnel. 140 | // Will send "Hello World!" to the target and receive "Goodbye!" back. 141 | let t1 = tokio::spawn(async move { 142 | let conn = connector.call(target_dst).await.expect("tunnel"); 143 | let mut tcp = conn.into_inner(); 144 | 145 | tcp.write_all(b"Hello World!").await.expect("write 1"); 146 | 147 | let mut buf = [0u8; 64]; 148 | let n = tcp.read(&mut buf).await.expect("read 1"); 149 | assert_eq!(&buf[..n], b"Goodbye!"); 150 | }); 151 | 152 | // Proxy 153 | // 154 | // Will receive CONNECT command from client. 155 | // Will connect to target and success code back to client. 156 | // Will blindly tunnel between client and target. 
157 | let t2 = tokio::spawn(async move { 158 | let (mut to_client, _) = proxy_tcp.accept().await.expect("accept"); 159 | let mut buf = [0u8; 513]; 160 | 161 | // negotiation req/res 162 | let n = to_client.read(&mut buf).await.expect("read 1"); 163 | assert_eq!(&buf[..n], [0x05, 0x01, 0x02]); 164 | 165 | to_client.write_all(&[0x05, 0x02]).await.expect("write 1"); 166 | 167 | // auth req/res 168 | let n = to_client.read(&mut buf).await.expect("read 2"); 169 | let [u1, u2, u3, u4] = b"user"; 170 | let [p1, p2, p3, p4] = b"pass"; 171 | let message = [0x01, 0x04, *u1, *u2, *u3, *u4, 0x04, *p1, *p2, *p3, *p4]; 172 | assert_eq!(&buf[..n], message); 173 | 174 | to_client.write_all(&[0x01, 0x00]).await.expect("write 2"); 175 | 176 | // command req/res 177 | let n = to_client.read(&mut buf).await.expect("read 3"); 178 | let [p1, p2] = target_addr.port().to_be_bytes(); 179 | let [ip1, ip2, ip3, ip4] = [0x7f, 0x00, 0x00, 0x01]; 180 | let message = [0x05, 0x01, 0x00, 0x01, ip1, ip2, ip3, ip4, p1, p2]; 181 | assert_eq!(&buf[..n], message); 182 | 183 | let mut to_target = TcpStream::connect(target_addr).await.expect("connect"); 184 | 185 | let message = [0x05, 0x00, 0x00, 0x01, ip1, ip2, ip3, ip4, p1, p2]; 186 | to_client.write_all(&message).await.expect("write 3"); 187 | 188 | let (from_client, from_target) = 189 | tokio::io::copy_bidirectional(&mut to_client, &mut to_target) 190 | .await 191 | .expect("proxy"); 192 | 193 | assert_eq!(from_client, 12); 194 | assert_eq!(from_target, 8) 195 | }); 196 | 197 | // Target server 198 | // 199 | // Will accept connection from proxy server 200 | // Will receive "Hello World!" from the client and return "Goodbye!" 
201 | let t3 = tokio::spawn(async move { 202 | let (mut io, _) = target_tcp.accept().await.expect("accept"); 203 | let mut buf = [0u8; 64]; 204 | 205 | let n = io.read(&mut buf).await.expect("read 1"); 206 | assert_eq!(&buf[..n], b"Hello World!"); 207 | 208 | io.write_all(b"Goodbye!").await.expect("write 1"); 209 | }); 210 | 211 | t1.await.expect("task - client"); 212 | t2.await.expect("task - proxy"); 213 | t3.await.expect("task - target"); 214 | } 215 | 216 | #[cfg(not(miri))] 217 | #[tokio::test] 218 | async fn test_socks_v5_with_server_resolved_domain_works() { 219 | let proxy_tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind"); 220 | let proxy_addr = proxy_tcp.local_addr().expect("local_addr"); 221 | let proxy_addr = format!("http://{proxy_addr}").parse().expect("uri"); 222 | 223 | let mut connector = SocksV5::new(proxy_addr, HttpConnector::new()) 224 | .with_auth("user".into(), "pass".into()) 225 | .local_dns(false); 226 | 227 | // Client 228 | // 229 | // Will use `SocksV5` to establish proxy tunnel. 230 | // Will send "Hello World!" to the target and receive "Goodbye!" back. 231 | let t1 = tokio::spawn(async move { 232 | let _conn = connector 233 | .call("https://hyper.rs:443".try_into().unwrap()) 234 | .await 235 | .expect("tunnel"); 236 | }); 237 | 238 | // Proxy 239 | // 240 | // Will receive CONNECT command from client. 241 | // Will connect to target and success code back to client. 242 | // Will blindly tunnel between client and target. 
243 | let t2 = tokio::spawn(async move { 244 | let (mut to_client, _) = proxy_tcp.accept().await.expect("accept"); 245 | let mut buf = [0u8; 513]; 246 | 247 | // negotiation req/res 248 | let n = to_client.read(&mut buf).await.expect("read 1"); 249 | assert_eq!(&buf[..n], [0x05, 0x01, 0x02]); 250 | 251 | to_client.write_all(&[0x05, 0x02]).await.expect("write 1"); 252 | 253 | // auth req/res 254 | let n = to_client.read(&mut buf).await.expect("read 2"); 255 | let [u1, u2, u3, u4] = b"user"; 256 | let [p1, p2, p3, p4] = b"pass"; 257 | let message = [0x01, 0x04, *u1, *u2, *u3, *u4, 0x04, *p1, *p2, *p3, *p4]; 258 | assert_eq!(&buf[..n], message); 259 | 260 | to_client.write_all(&[0x01, 0x00]).await.expect("write 2"); 261 | 262 | // command req/res 263 | let n = to_client.read(&mut buf).await.expect("read 3"); 264 | 265 | let host = "hyper.rs"; 266 | let port: u16 = 443; 267 | let mut message = vec![0x05, 0x01, 0x00, 0x03, host.len() as u8]; 268 | message.extend(host.bytes()); 269 | message.extend(port.to_be_bytes()); 270 | assert_eq!(&buf[..n], message); 271 | 272 | let mut message = vec![0x05, 0x00, 0x00, 0x03, host.len() as u8]; 273 | message.extend(host.bytes()); 274 | message.extend(port.to_be_bytes()); 275 | to_client.write_all(&message).await.expect("write 3"); 276 | }); 277 | 278 | t1.await.expect("task - client"); 279 | t2.await.expect("task - proxy"); 280 | } 281 | 282 | #[cfg(not(miri))] 283 | #[tokio::test] 284 | async fn test_socks_v5_with_locally_resolved_domain_works() { 285 | let proxy_tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind"); 286 | let proxy_addr = proxy_tcp.local_addr().expect("local_addr"); 287 | let proxy_addr = format!("http://{proxy_addr}").parse().expect("uri"); 288 | 289 | let mut connector = SocksV5::new(proxy_addr, HttpConnector::new()) 290 | .with_auth("user".into(), "pass".into()) 291 | .local_dns(true); 292 | 293 | // Client 294 | // 295 | // Will use `SocksV5` to establish proxy tunnel. 296 | // Will send "Hello World!" 
to the target and receive "Goodbye!" back. 297 | let t1 = tokio::spawn(async move { 298 | let _conn = connector 299 | .call("https://hyper.rs:443".try_into().unwrap()) 300 | .await 301 | .expect("tunnel"); 302 | }); 303 | 304 | // Proxy 305 | // 306 | // Will receive CONNECT command from client. 307 | // Will connect to target and success code back to client. 308 | // Will blindly tunnel between client and target. 309 | let t2 = tokio::spawn(async move { 310 | let (mut to_client, _) = proxy_tcp.accept().await.expect("accept"); 311 | let mut buf = [0u8; 513]; 312 | 313 | // negotiation req/res 314 | let n = to_client.read(&mut buf).await.expect("read 1"); 315 | assert_eq!(&buf[..n], [0x05, 0x01, 0x02]); 316 | 317 | to_client.write_all(&[0x05, 0x02]).await.expect("write 1"); 318 | 319 | // auth req/res 320 | let n = to_client.read(&mut buf).await.expect("read 2"); 321 | let [u1, u2, u3, u4] = b"user"; 322 | let [p1, p2, p3, p4] = b"pass"; 323 | let message = [0x01, 0x04, *u1, *u2, *u3, *u4, 0x04, *p1, *p2, *p3, *p4]; 324 | assert_eq!(&buf[..n], message); 325 | 326 | to_client.write_all(&[0x01, 0x00]).await.expect("write 2"); 327 | 328 | // command req/res 329 | let n = to_client.read(&mut buf).await.expect("read 3"); 330 | let message = [0x05, 0x01, 0x00]; 331 | assert_eq!(&buf[..3], message); 332 | assert!(buf[3] == 0x01 || buf[3] == 0x04); // IPv4 or IPv6 333 | assert_eq!(n, 4 + 4 * (buf[3] as usize) + 2); 334 | 335 | let message = vec![0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0]; 336 | to_client.write_all(&message).await.expect("write 3"); 337 | }); 338 | 339 | t1.await.expect("task - client"); 340 | t2.await.expect("task - proxy"); 341 | } 342 | 343 | #[cfg(not(miri))] 344 | #[tokio::test] 345 | async fn test_socks_v4_works() { 346 | let proxy_tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind"); 347 | let proxy_addr = proxy_tcp.local_addr().expect("local_addr"); 348 | let proxy_dst = format!("http://{proxy_addr}").parse().expect("uri"); 349 | 350 | let 
target_tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind"); 351 | let target_addr = target_tcp.local_addr().expect("local_addr"); 352 | let target_dst = format!("http://{target_addr}").parse().expect("uri"); 353 | 354 | let mut connector = SocksV4::new(proxy_dst, HttpConnector::new()); 355 | 356 | // Client 357 | // 358 | // Will use `SocksV4` to establish proxy tunnel. 359 | // Will send "Hello World!" to the target and receive "Goodbye!" back. 360 | let t1 = tokio::spawn(async move { 361 | let conn = connector.call(target_dst).await.expect("tunnel"); 362 | let mut tcp = conn.into_inner(); 363 | 364 | tcp.write_all(b"Hello World!").await.expect("write 1"); 365 | 366 | let mut buf = [0u8; 64]; 367 | let n = tcp.read(&mut buf).await.expect("read 1"); 368 | assert_eq!(&buf[..n], b"Goodbye!"); 369 | }); 370 | 371 | // Proxy 372 | // 373 | // Will receive CONNECT command from client. 374 | // Will connect to target and success code back to client. 375 | // Will blindly tunnel between client and target. 
376 | let t2 = tokio::spawn(async move { 377 | let (mut to_client, _) = proxy_tcp.accept().await.expect("accept"); 378 | let mut buf = [0u8; 512]; 379 | 380 | let [p1, p2] = target_addr.port().to_be_bytes(); 381 | let [ip1, ip2, ip3, ip4] = [127, 0, 0, 1]; 382 | let message = [4, 0x01, p1, p2, ip1, ip2, ip3, ip4, 0, 0]; 383 | let n = to_client.read(&mut buf).await.expect("read"); 384 | assert_eq!(&buf[..n], message); 385 | 386 | let mut to_target = TcpStream::connect(target_addr).await.expect("connect"); 387 | 388 | let message = [0, 90, p1, p2, ip1, ip2, ip3, ip4]; 389 | to_client.write_all(&message).await.expect("write"); 390 | 391 | let (from_client, from_target) = 392 | tokio::io::copy_bidirectional(&mut to_client, &mut to_target) 393 | .await 394 | .expect("proxy"); 395 | 396 | assert_eq!(from_client, 12); 397 | assert_eq!(from_target, 8) 398 | }); 399 | 400 | // Target server 401 | // 402 | // Will accept connection from proxy server 403 | // Will receive "Hello World!" from the client and return "Goodbye!" 
404 | let t3 = tokio::spawn(async move { 405 | let (mut io, _) = target_tcp.accept().await.expect("accept"); 406 | let mut buf = [0u8; 64]; 407 | 408 | let n = io.read(&mut buf).await.expect("read 1"); 409 | assert_eq!(&buf[..n], b"Hello World!"); 410 | 411 | io.write_all(b"Goodbye!").await.expect("write 1"); 412 | }); 413 | 414 | t1.await.expect("task - client"); 415 | t2.await.expect("task - proxy"); 416 | t3.await.expect("task - target"); 417 | } 418 | 419 | #[cfg(not(miri))] 420 | #[tokio::test] 421 | async fn test_socks_v5_optimistic_works() { 422 | let proxy_tcp = TcpListener::bind("127.0.0.1:0").await.expect("bind"); 423 | let proxy_addr = proxy_tcp.local_addr().expect("local_addr"); 424 | let proxy_dst = format!("http://{proxy_addr}").parse().expect("uri"); 425 | 426 | let target_addr = std::net::SocketAddr::new([127, 0, 0, 1].into(), 1234); 427 | let target_dst = format!("http://{target_addr}").parse().expect("uri"); 428 | 429 | let mut connector = SocksV5::new(proxy_dst, HttpConnector::new()) 430 | .with_auth("ABC".into(), "XYZ".into()) 431 | .send_optimistically(true); 432 | 433 | // Client 434 | // 435 | // Will use `SocksV5` to establish proxy tunnel. 436 | // Will send the whole handshake (negotiation, auth, connect) without waiting for replies. 437 | let t1 = tokio::spawn(async move { 438 | let _ = connector.call(target_dst).await.expect("tunnel"); 439 | }); 440 | 441 | // Proxy 442 | // 443 | // Will receive SOCKS handshake from client. 444 | // Will send all handshake replies (with success codes) back in a single write. 445 | // Never connects to the target: only the handshake is exercised.
446 | let t2 = tokio::spawn(async move { 447 | let (mut to_client, _) = proxy_tcp.accept().await.expect("accept"); 448 | let [p1, p2] = target_addr.port().to_be_bytes(); 449 | 450 | let mut buf = [0; 22]; 451 | let request = [ 452 | 5, 1, 2, // Negotiation: offer username/password auth 453 | 1, 3, 65, 66, 67, 3, 88, 89, 90, // Auth ("ABC"/"XYZ") 454 | 5, 1, 0, 1, 127, 0, 0, 1, p1, p2, // Connect request (IPv4) 455 | ]; 456 | 457 | let response = [ 458 | 5, 2, // Negotiation: username/password selected 459 | 1, 0, // Auth: success 460 | 5, 0, 0, 1, 127, 0, 0, 1, p1, p2, // Reply: success 461 | ]; 462 | 463 | // Read the whole handshake before replying (only completes if the client sends optimistically) 464 | to_client.read_exact(&mut buf).await.expect("read"); 465 | assert_eq!(request.as_slice(), buf); 466 | 467 | // Send all handshake replies back at once 468 | to_client 469 | .write_all(response.as_slice()) 470 | .await 471 | .expect("write"); 472 | 473 | to_client.flush().await.expect("flush"); 474 | }); 475 | 476 | t1.await.expect("task - client"); 477 | t2.await.expect("task - proxy"); 478 | } 479 | -------------------------------------------------------------------------------- /tests/test_utils/mod.rs: -------------------------------------------------------------------------------- 1 | use std::pin::Pin; 2 | use std::sync::atomic::{AtomicUsize, Ordering}; 3 | use std::sync::Arc; 4 | 5 | use futures_channel::mpsc; 6 | use futures_util::task::{Context, Poll}; 7 | use futures_util::Future; 8 | use futures_util::TryFutureExt; 9 | use hyper::Uri; 10 | use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf}; 11 | use tokio::net::TcpStream; 12 | 13 | use hyper::rt::ReadBufCursor; 14 | 15 | use hyper_util::client::legacy::connect::HttpConnector; 16 | use hyper_util::client::legacy::connect::{Connected, Connection}; 17 | use hyper_util::rt::TokioIo; 18 | 19 | #[derive(Clone)] 20 | pub struct DebugConnector { 21 | pub http: HttpConnector, 22 | pub closes: mpsc::Sender<()>, 23 | pub connects: Arc, 24 | pub is_proxy: bool, 25 | pub alpn_h2: bool, 26 | } 27 | 28 | impl DebugConnector { 29 | pub fn new() -> DebugConnector {
30 | let http = HttpConnector::new(); 31 | let (tx, _) = mpsc::channel(10); 32 | DebugConnector::with_http_and_closes(http, tx) 33 | } 34 | 35 | pub fn with_http_and_closes(http: HttpConnector, closes: mpsc::Sender<()>) -> DebugConnector { 36 | DebugConnector { 37 | http, 38 | closes, 39 | connects: Arc::new(AtomicUsize::new(0)), 40 | is_proxy: false, 41 | alpn_h2: false, 42 | } 43 | } 44 | 45 | pub fn proxy(mut self) -> Self { 46 | self.is_proxy = true; 47 | self 48 | } 49 | } 50 | 51 | impl tower_service::Service for DebugConnector { 52 | type Response = DebugStream; 53 | type Error = >::Error; 54 | type Future = Pin> + Send>>; 55 | 56 | fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { 57 | // don't forget to check inner service is ready :) 58 | tower_service::Service::::poll_ready(&mut self.http, cx) 59 | } 60 | 61 | fn call(&mut self, dst: Uri) -> Self::Future { 62 | self.connects.fetch_add(1, Ordering::SeqCst); 63 | let closes = self.closes.clone(); 64 | let is_proxy = self.is_proxy; 65 | let is_alpn_h2 = self.alpn_h2; 66 | Box::pin(self.http.call(dst).map_ok(move |tcp| DebugStream { 67 | tcp, 68 | on_drop: closes, 69 | is_alpn_h2, 70 | is_proxy, 71 | })) 72 | } 73 | } 74 | 75 | pub struct DebugStream { 76 | tcp: TokioIo, 77 | on_drop: mpsc::Sender<()>, 78 | is_alpn_h2: bool, 79 | is_proxy: bool, 80 | } 81 | 82 | impl Drop for DebugStream { 83 | fn drop(&mut self) { 84 | let _ = self.on_drop.try_send(()); 85 | } 86 | } 87 | 88 | impl Connection for DebugStream { 89 | fn connected(&self) -> Connected { 90 | let connected = self.tcp.connected().proxy(self.is_proxy); 91 | 92 | if self.is_alpn_h2 { 93 | connected.negotiated_h2() 94 | } else { 95 | connected 96 | } 97 | } 98 | } 99 | 100 | impl hyper::rt::Read for DebugStream { 101 | fn poll_read( 102 | mut self: Pin<&mut Self>, 103 | cx: &mut Context<'_>, 104 | buf: ReadBufCursor<'_>, 105 | ) -> Poll> { 106 | hyper::rt::Read::poll_read(Pin::new(&mut self.tcp), cx, buf) 107 | } 108 | } 109 | 110 | impl
hyper::rt::Write for DebugStream { 111 | fn poll_write( 112 | mut self: Pin<&mut Self>, 113 | cx: &mut Context<'_>, 114 | buf: &[u8], 115 | ) -> Poll> { 116 | hyper::rt::Write::poll_write(Pin::new(&mut self.tcp), cx, buf) 117 | } 118 | 119 | fn poll_flush( 120 | mut self: Pin<&mut Self>, 121 | cx: &mut Context<'_>, 122 | ) -> Poll> { 123 | hyper::rt::Write::poll_flush(Pin::new(&mut self.tcp), cx) 124 | } 125 | 126 | fn poll_shutdown( 127 | mut self: Pin<&mut Self>, 128 | cx: &mut Context<'_>, 129 | ) -> Poll> { 130 | hyper::rt::Write::poll_shutdown(Pin::new(&mut self.tcp), cx) 131 | } 132 | 133 | fn is_write_vectored(&self) -> bool { 134 | hyper::rt::Write::is_write_vectored(&self.tcp) 135 | } 136 | 137 | fn poll_write_vectored( 138 | mut self: Pin<&mut Self>, 139 | cx: &mut Context<'_>, 140 | bufs: &[std::io::IoSlice<'_>], 141 | ) -> Poll> { 142 | hyper::rt::Write::poll_write_vectored(Pin::new(&mut self.tcp), cx, bufs) 143 | } 144 | } 145 | 146 | impl AsyncWrite for DebugStream { 147 | fn poll_shutdown( 148 | mut self: Pin<&mut Self>, 149 | cx: &mut Context<'_>, 150 | ) -> Poll> { 151 | Pin::new(self.tcp.inner_mut()).poll_shutdown(cx) 152 | } 153 | 154 | fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 155 | Pin::new(self.tcp.inner_mut()).poll_flush(cx) 156 | } 157 | 158 | fn poll_write( 159 | mut self: Pin<&mut Self>, 160 | cx: &mut Context<'_>, 161 | buf: &[u8], 162 | ) -> Poll> { 163 | Pin::new(self.tcp.inner_mut()).poll_write(cx, buf) 164 | } 165 | } 166 | 167 | impl AsyncRead for DebugStream { 168 | fn poll_read( 169 | mut self: Pin<&mut Self>, 170 | cx: &mut Context<'_>, 171 | buf: &mut ReadBuf<'_>, 172 | ) -> Poll> { 173 | Pin::new(self.tcp.inner_mut()).poll_read(cx, buf) 174 | } 175 | } 176 | --------------------------------------------------------------------------------