├── .github ├── dependabot.yml └── workflows │ ├── ci.yml │ └── publish.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── examples ├── cat.rs ├── charset.rs ├── head.rs ├── imdb.rs ├── multipart.rs ├── nhlapi.rs ├── post.rs ├── post_json.rs └── session.rs ├── rustfmt.toml ├── src ├── charsets.rs ├── error.rs ├── happy.rs ├── lib.rs ├── multipart.rs ├── multipart_crate │ ├── lazy.rs │ └── mod.rs ├── parsing │ ├── body_reader.rs │ ├── buffers.rs │ ├── chunked_reader.rs │ ├── compressed_reader.rs │ ├── mod.rs │ ├── response.rs │ ├── response_reader.rs │ └── text_reader.rs ├── request │ ├── body.rs │ ├── builder.rs │ ├── mod.rs │ ├── proxy.rs │ ├── session.rs │ └── settings.rs ├── streams.rs └── tls │ ├── mod.rs │ ├── native_tls_impl.rs │ ├── no_tls_impl.rs │ └── rustls_impl.rs ├── tests ├── test_invalid_certs.rs ├── test_multipart.rs ├── test_proxy.rs ├── test_redirection.rs ├── test_timeout.rs └── tools │ ├── cert.pem │ ├── generate-certs.bash │ ├── key.pem │ ├── mod.rs │ ├── proxy.rs │ └── servers.rs └── tools ├── clippy.bash └── tests.bash /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "cargo" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "daily" 12 | groups: 13 | all: 14 | patterns: 15 | - "*" 16 | update-types: 17 | - "minor" 18 | - "patch" 19 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Test & Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | env: 12 | CARGO_TERM_COLOR: always 13 | 14 | jobs: 15 | test: 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | os: [ubuntu-latest, macos-latest, windows-latest] 20 | runs-on: ${{ matrix.os }} 21 | steps: 22 | - uses: actions/checkout@v4 23 | - uses: actions-rs/toolchain@v1 24 | with: 25 | toolchain: stable 26 | - uses: taiki-e/install-action@v2 27 | with: 28 | tool: nextest 29 | - run: ./tools/tests.bash 30 | shell: bash 31 | 32 | clippy: 33 | runs-on: ubuntu-latest 34 | steps: 35 | - uses: actions/checkout@v4 36 | - uses: actions-rs/toolchain@v1 37 | with: 38 | toolchain: stable 39 | - run: rustup component add clippy 40 | - run: cargo fmt --check 41 | - uses: actions-rs/clippy-check@v1 42 | with: 43 | token: ${{ secrets.GITHUB_TOKEN }} 44 | args: --all-features --all-targets 45 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to crates.io 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | jobs: 9 | publish: 10 | name: Publish to crates.io 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout sources 14 | uses: 
actions/checkout@v4 15 | - name: Install stable toolchain 16 | uses: actions-rs/toolchain@v1 17 | with: 18 | profile: minimal 19 | toolchain: stable 20 | override: true 21 | - run: cargo publish --token ${CRATES_TOKEN} 22 | env: 23 | CRATES_TOKEN: ${{ secrets.CRATES_TOKEN }} 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | **/*.rs.bk 3 | Cargo.lock 4 | .vscode 5 | .idea 6 | *.iml 7 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Simon Bernier St-Pierre "] 3 | edition = "2021" 4 | license = "MPL-2.0" 5 | name = "attohttpc" 6 | version = "0.29.2" 7 | 8 | categories = [ 9 | "network-programming", 10 | "web-programming", 11 | "web-programming::http-client", 12 | ] 13 | description = "Small and lightweight HTTP client" 14 | documentation = "https://docs.rs/attohttpc" 15 | homepage = "https://github.com/sbstp/attohttpc" 16 | keywords = ["http", "https", "client", "request", "response"] 17 | readme = "README.md" 18 | repository = "https://github.com/sbstp/attohttpc" 19 | 20 | [dependencies] 21 | base64 = { version = "0.22.0" } 22 | encoding_rs = { version = "0.8.31", optional = true } 23 | encoding_rs_io = { version = "0.1.7", optional = true } 24 | flate2 = { version = "1.0.24", default-features = false, optional = true } 25 | http = "1" 26 | log = "0.4.17" 27 | mime = { version = "0.3.16", optional = true } 28 | mime_guess = { version = "2.0.5", optional = true } 29 | native-tls = { version = "0.2.10", optional = true } 30 | rand = { version = "0.9.0", optional = true } 31 | rustls-native-certs = { version = "0.8.1", optional = true } 32 | rustls = { version = "0.23.22", optional = true } 33 | serde = { version = "1.0.143", optional = true } 34 | serde_json = { version 
= "1.0.83", optional = true } 35 | serde_urlencoded = { version = "0.7.1", optional = true } 36 | url = "2.2.2" 37 | webpki-roots = { version = "0.26.8", optional = true } 38 | 39 | [dev-dependencies] 40 | anyhow = "1.0.61" 41 | axum = { version = "0.8.1", features = ["multipart"] } 42 | axum-server = { version = "0.7.1", features = ["tls-rustls"] } 43 | bytes = "1.10.0" 44 | env_logger = "0.11.0" 45 | http-body-util = "0.1.2" 46 | hyper = { version = "1.6.0", features = ["full"] } 47 | hyper-util = "0.1.10" 48 | lazy_static = "1.4.0" 49 | rustls = "0.23.22" 50 | tokio = { version = "1.20.1", features = ["full"] } 51 | tokio-rustls = "0.26.1" 52 | 53 | [features] 54 | basic-auth = [] 55 | charsets = ["encoding_rs", "encoding_rs_io"] 56 | # The following three compress features are mutually exclusive. 57 | compress = ["flate2/default"] 58 | compress-zlib = ["flate2/zlib"] 59 | compress-zlib-ng = ["flate2/zlib-ng"] 60 | default = ["compress", "tls-native"] 61 | form = ["serde", "serde_urlencoded"] 62 | json = ["serde", "serde_json"] 63 | multipart-form = ["mime", "mime_guess", "rand"] 64 | # The following TLS features are mutually exclusive 65 | tls-native = ["native-tls"] 66 | tls-rustls-webpki-roots = ["__rustls", "webpki-roots"] 67 | tls-rustls-native-roots = ["__rustls", "rustls-native-certs"] 68 | # This feature depends on tls-native 69 | tls-native-vendored = ["native-tls/vendored"] 70 | # These features are provided for backwards compatibility 71 | tls = ["tls-native"] 72 | rustls = ["tls-rustls-webpki-roots"] 73 | tls-rustls = ["tls-rustls-webpki-roots"] 74 | tls-vendored = ["tls-native-vendored"] 75 | # Internal feature used to indicate rustls support 76 | __rustls = ["dep:rustls"] 77 | 78 | [package.metadata.docs.rs] 79 | all-features = true 80 | 81 | [[example]] 82 | name = "cat" 83 | path = "examples/cat.rs" 84 | required-features = ["default"] 85 | 86 | [[example]] 87 | name = "imdb" 88 | path = "examples/imdb.rs" 89 | required-features = ["tls-native"] 
90 | 91 | [[example]] 92 | name = "nhlapi" 93 | path = "examples/nhlapi.rs" 94 | required-features = ["tls-native"] 95 | 96 | [[example]] 97 | name = "post_json" 98 | path = "examples/post_json.rs" 99 | required-features = ["json"] 100 | 101 | [[example]] 102 | name = "post" 103 | path = "examples/post.rs" 104 | required-features = ["tls-native"] 105 | 106 | [[example]] 107 | name = "charset" 108 | path = "examples/charset.rs" 109 | required-features = ["charsets"] 110 | 111 | [[example]] 112 | name = "multipart" 113 | path = "examples/multipart.rs" 114 | required-features = ["multipart-form"] 115 | 116 | [[test]] 117 | name = "test_invalid_certs" 118 | path = "tests/test_invalid_certs.rs" 119 | required-features = ["tls-native"] 120 | 121 | [[test]] 122 | name = "test_multipart" 123 | path = "tests/test_multipart.rs" 124 | required-features = ["multipart-form"] 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # attohttpc 2 | [Documentation](https://docs.rs/attohttpc) | [Crates.io](https://crates.io/crates/attohttpc) | [Repository](https://github.com/sbstp/attohttpc) 3 | 4 | ## Why attohttpc? 5 | This project's goal is to provide a lightweight and simple HTTP client for the Rust ecosystem. The intended use is for 6 | projects that have HTTP needs where performance is not critical or when HTTP is not the main purpose of the application. 7 | Note that the project still tries to perform well and avoid allocation where possible, but stays away from Rust's 8 | asynchronous stack to provide a crate that's as small as possible. Features are provided behind feature flags when 9 | possible to allow users to get just what they need. 
Here are the goals of the project: 10 | 11 | * Lightweight 12 | * Secure 13 | * Easy to use 14 | * Modular 15 | * HTTP/1.1 16 | * Use quality crates from the ecosystem (`http`, `url`, `encoding_rs`) rather than reinventing the wheel. 17 | 18 | ## Features 19 | * `charsets` support for decoding more text encodings than just UTF-8 20 | * `compress` support for decompressing response bodies using `miniz_oxide` (**default**) 21 | * `compress-zlib` support for decompressing response bodies using `zlib` instead of `miniz_oxide` (see [flate2 backends](https://github.com/rust-lang/flate2-rs#backends)) 22 | * `compress-zlib-ng` support for decompressing response bodies using `zlib-ng` instead of `miniz_oxide` (see [flate2 backends](https://github.com/rust-lang/flate2-rs#backends)) 23 | * `json` support for serialization and deserialization 24 | * `form` support for URL-encoded forms (does not include support for multipart) 25 | * `multipart-form` support for multipart forms (does not include support for URL encoding) 26 | * `tls-native` support for TLS connections using the `native-tls` crate (**default**) 27 | * `tls-native-vendored` activates the `vendored` feature of `native-tls` 28 | * `tls-rustls-webpki-roots` support for TLS connections using `rustls` instead of `native-tls` with Web PKI roots 29 | * `tls-rustls-native-roots` support for TLS connections using `rustls` with root certificates loaded from the `rustls-native-certs` crate 30 | 31 | ## Usage 32 | See the `examples/` folder in the repository for more use cases. 33 | ```rust 34 | let resp = attohttpc::post("https://my-api.com/do/something").json(&request)?.send()?; 35 | if resp.is_success() { 36 | let response = resp.json()?; 37 | // ... 38 | } 39 | ``` 40 | 41 | ## Current feature set 42 | * Query parameters, Request headers, Bodies, etc. 43 | * TLS, adding trusted certificates, disabling verification, etc.
for both `native-tls` and `rustls` 44 | * Automatic redirection 45 | * Streaming response body 46 | * Multiple text encodings 47 | * Automatic compression/decompression with gzip or deflate 48 | * Transfer-Encoding: chunked 49 | * serde/json support 50 | * HTTP Proxies & `HTTP_PROXY`, `HTTPS_PROXY`, `NO_PROXY` environment variables. 51 | * [Happy Eyeballs](https://en.wikipedia.org/wiki/Happy_Eyeballs) 52 | * Authentication (partial support) 53 | 54 | ## License 55 | This project is licensed under the `MPL-2.0`. 56 | -------------------------------------------------------------------------------- /examples/cat.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use attohttpc::Result; 4 | 5 | fn main() -> Result { 6 | env_logger::init(); 7 | 8 | let url: String = env::args().collect::>().into_iter().nth(1).expect("missing url"); 9 | 10 | let resp = attohttpc::get(url).send()?; 11 | println!("Status: {:?}", resp.status()); 12 | println!("Headers:\n{:#?}", resp.headers()); 13 | println!("Body:\n{}", resp.text()?); 14 | 15 | Ok(()) 16 | } 17 | -------------------------------------------------------------------------------- /examples/charset.rs: -------------------------------------------------------------------------------- 1 | fn main() -> Result<(), attohttpc::Error> { 2 | env_logger::init(); 3 | 4 | let resp = attohttpc::get("https://rust-lang.org/").send()?; 5 | println!("{}", resp.text()?); 6 | Ok(()) 7 | } 8 | -------------------------------------------------------------------------------- /examples/head.rs: -------------------------------------------------------------------------------- 1 | fn main() -> attohttpc::Result { 2 | env_logger::init(); 3 | 4 | let resp = attohttpc::head("http://httpbin.org").send()?; 5 | println!("Status: {:?}", resp.status()); 6 | println!("Headers:\n{:#?}", resp.headers()); 7 | 8 | Ok(()) 9 | } 10 | 
-------------------------------------------------------------------------------- /examples/imdb.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | 3 | fn main() -> attohttpc::Result { 4 | env_logger::init(); 5 | 6 | let resp = attohttpc::get("https://datasets.imdbws.com/title.basics.tsv.gz").send()?; 7 | println!("Status: {:?}", resp.status()); 8 | println!("Headers:\n{:#?}", resp.headers()); 9 | if resp.is_success() { 10 | let file = File::create("title.basics.tsv.gz")?; 11 | let n = resp.write_to(file)?; 12 | println!("Wrote {n} bytes to the file."); 13 | } 14 | 15 | Ok(()) 16 | } 17 | -------------------------------------------------------------------------------- /examples/multipart.rs: -------------------------------------------------------------------------------- 1 | fn main() -> attohttpc::Result { 2 | env_logger::init(); 3 | 4 | let file = attohttpc::MultipartFile::new("file", b"Hello, world!") 5 | .with_type("text/plain")? 
6 | .with_filename("hello.txt"); 7 | let form = attohttpc::MultipartBuilder::new() 8 | .with_text("Hello", "world!") 9 | .with_file(file) 10 | .build()?; 11 | 12 | let resp = attohttpc::post("http://httpbin.org/post").body(form).send()?; 13 | 14 | println!("Status: {:?}", resp.status()); 15 | println!("Headers:\n{:#?}", resp.headers()); 16 | println!("Body:\n{}", resp.text()?); 17 | 18 | Ok(()) 19 | } 20 | -------------------------------------------------------------------------------- /examples/nhlapi.rs: -------------------------------------------------------------------------------- 1 | fn main() -> attohttpc::Result { 2 | env_logger::init(); 3 | 4 | let resp = attohttpc::get("https://statsapi.web.nhl.com/api/v1/schedule").send()?; 5 | println!("Status: {:?}", resp.status()); 6 | println!("Headers:\n{:#?}", resp.headers()); 7 | println!("Body:\n{}", resp.text()?); 8 | 9 | Ok(()) 10 | } 11 | -------------------------------------------------------------------------------- /examples/post.rs: -------------------------------------------------------------------------------- 1 | fn main() -> attohttpc::Result { 2 | env_logger::init(); 3 | 4 | let resp = attohttpc::post("https://httpbin.org/post") 5 | .text("hello, world!") 6 | .send()?; 7 | 8 | println!("Status: {:?}", resp.status()); 9 | println!("Headers:\n{:#?}", resp.headers()); 10 | println!("Body:\n{}", resp.text()?); 11 | 12 | Ok(()) 13 | } 14 | -------------------------------------------------------------------------------- /examples/post_json.rs: -------------------------------------------------------------------------------- 1 | use serde_json::json; 2 | 3 | fn main() -> attohttpc::Result { 4 | env_logger::init(); 5 | 6 | let body = json!({ 7 | "hello": "world", 8 | }); 9 | 10 | let resp = attohttpc::post("http://httpbin.org/post").json(&body)?.send()?; 11 | println!("Status: {:?}", resp.status()); 12 | println!("Headers:\n{:#?}", resp.headers()); 13 | println!("Body:\n{}", resp.text_utf8()?); 14 | 15 | 
Ok(()) 16 | } 17 | -------------------------------------------------------------------------------- /examples/session.rs: -------------------------------------------------------------------------------- 1 | use attohttpc::{Result, Session}; 2 | 3 | fn main() -> Result { 4 | let mut sess = Session::new(); 5 | sess.header("Authorization", "Bearer please let me in!"); 6 | 7 | let resp = sess.get("https://httpbin.org/get").send()?; 8 | println!("{}", resp.text()?); 9 | 10 | Ok(()) 11 | } 12 | -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | max_width = 120 2 | -------------------------------------------------------------------------------- /src/charsets.rs: -------------------------------------------------------------------------------- 1 | //! This module is a clean re-export of the `encoding_rs` crate. 2 | //! You can probably find the charset you need in here. 3 | 4 | use encoding_rs::Encoding; 5 | 6 | /// This type is an alias to the `encoding_rs::Encoding` type, used 7 | /// to normalize the name across the crate. 
8 | pub type Charset = &'static Encoding; 9 | 10 | pub use encoding_rs::{ 11 | BIG5, EUC_JP, EUC_KR, GB18030, GBK, IBM866, ISO_2022_JP, ISO_8859_10, ISO_8859_13, ISO_8859_14, ISO_8859_15, 12 | ISO_8859_16, ISO_8859_2, ISO_8859_3, ISO_8859_4, ISO_8859_5, ISO_8859_6, ISO_8859_7, ISO_8859_8, ISO_8859_8_I, 13 | KOI8_R, KOI8_U, MACINTOSH, SHIFT_JIS, UTF_16BE, UTF_16LE, UTF_8, WINDOWS_1250, WINDOWS_1251, WINDOWS_1252, 14 | WINDOWS_1253, WINDOWS_1254, WINDOWS_1255, WINDOWS_1256, WINDOWS_1257, WINDOWS_1258, WINDOWS_874, X_MAC_CYRILLIC, 15 | }; 16 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use std::convert::Infallible; 2 | use std::error::Error as StdError; 3 | use std::fmt::{self, Display}; 4 | use std::io; 5 | use std::result; 6 | 7 | /// Errors than can occur while parsing the response from the server. 8 | #[derive(Debug)] 9 | pub enum InvalidResponseKind { 10 | /// Invalid or missing Location header in redirection 11 | LocationHeader, 12 | /// Invalid redirection URL 13 | RedirectionUrl, 14 | /// Status line 15 | StatusLine, 16 | /// Status code 17 | StatusCode, 18 | /// Error parsing header 19 | Header, 20 | /// Error decoding chunk size 21 | ChunkSize, 22 | /// Error decoding chunk 23 | Chunk, 24 | /// Invalid Content-Length header 25 | ContentLength, 26 | } 27 | 28 | impl Display for InvalidResponseKind { 29 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 30 | use InvalidResponseKind::*; 31 | 32 | match self { 33 | LocationHeader => write!(f, "missing or invalid location header"), 34 | RedirectionUrl => write!(f, "invalid redirection url"), 35 | StatusLine => write!(f, "invalid status line"), 36 | StatusCode => write!(f, "invalid status code"), 37 | Header => write!(f, "invalid header"), 38 | ChunkSize => write!(f, "invalid chunk size"), 39 | Chunk => write!(f, "invalid chunk"), 40 | ContentLength => write!(f, "invalid 
content length"), 41 | } 42 | } 43 | } 44 | 45 | /// Common errors that can occur during HTTP requests. 46 | #[derive(Debug)] 47 | pub enum ErrorKind { 48 | /// CONNECT is not supported. 49 | ConnectNotSupported, 50 | /// Could not connect to proxy with CONNECT method. 51 | ConnectError { 52 | /// Status code from the proxy. 53 | status_code: http::StatusCode, 54 | /// Up to 10 KiB of body data from the proxy which might help diagnose the error. 55 | body: Vec, 56 | }, 57 | /// Error generated by the `http` crate. 58 | Http(http::Error), 59 | /// IO Error 60 | Io(io::Error), 61 | /// Invalid base URL given to the Request. 62 | InvalidBaseUrl, 63 | /// An URL with an invalid host was found while processing the request. 64 | InvalidUrlHost, 65 | /// The URL scheme is unknown and the port is missing. 66 | InvalidUrlPort, 67 | /// Server sent an invalid response. 68 | InvalidResponse(InvalidResponseKind), 69 | /// Too many redirections 70 | TooManyRedirections, 71 | /// Status code indicates failure 72 | StatusCode(http::StatusCode), 73 | /// JSON decoding/encoding error. 74 | #[cfg(feature = "json")] 75 | Json(serde_json::Error), 76 | /// Form-URL encoding error. 77 | #[cfg(feature = "form")] 78 | UrlEncoded(serde_urlencoded::ser::Error), 79 | /// TLS error encountered while connecting to an https server. 80 | #[cfg(feature = "tls-native")] 81 | Tls(native_tls::Error), 82 | /// TLS error encountered while connecting to an https server. 83 | #[cfg(all(feature = "__rustls", not(feature = "tls-native")))] 84 | Tls(rustls::Error), 85 | /// Invalid DNS name used for TLS certificate verification 86 | #[cfg(feature = "__rustls")] 87 | InvalidDNSName(String), 88 | /// Invalid mime type in a Multipart form 89 | InvalidMimeType(String), 90 | /// TLS was not enabled by features. 
91 | TlsDisabled, 92 | /// Empty cert store 93 | #[cfg(all(feature = "__rustls", not(feature = "tls-native")))] 94 | ServerCertVerifier(rustls::client::VerifierBuilderError), 95 | } 96 | 97 | /// A type that contains all the errors that can possibly occur while accessing an HTTP server. 98 | #[derive(Debug)] 99 | pub struct Error(pub(crate) Box); 100 | 101 | impl Error { 102 | /// Get a reference to the `ErrorKind` inside. 103 | pub fn kind(&self) -> &ErrorKind { 104 | &self.0 105 | } 106 | 107 | /// Comsume this `Error` and get the `ErrorKind` inside. 108 | pub fn into_kind(self) -> ErrorKind { 109 | *self.0 110 | } 111 | } 112 | 113 | impl Display for Error { 114 | fn fmt(&self, w: &mut fmt::Formatter) -> fmt::Result { 115 | use ErrorKind::*; 116 | 117 | match *self.0 { 118 | ConnectNotSupported => write!(w, "CONNECT is not supported"), 119 | ConnectError { status_code, .. } => write!(w, "Proxy CONNECT error: {status_code}"), 120 | Http(ref e) => write!(w, "Http Error: {e}"), 121 | Io(ref e) => write!(w, "Io Error: {e}"), 122 | InvalidBaseUrl => write!(w, "Invalid base URL"), 123 | InvalidUrlHost => write!(w, "URL is missing a host"), 124 | InvalidUrlPort => write!(w, "URL is missing a port"), 125 | InvalidResponse(ref k) => write!(w, "InvalidResponse: {k}"), 126 | TooManyRedirections => write!(w, "Too many redirections"), 127 | StatusCode(ref sc) => write!(w, "Status code {sc} indicates failure"), 128 | #[cfg(feature = "json")] 129 | Json(ref e) => write!(w, "Json Error: {e}"), 130 | #[cfg(feature = "form")] 131 | UrlEncoded(ref e) => write!(w, "URL Encoding Error: {e}"), 132 | #[cfg(any(feature = "tls-native", feature = "__rustls"))] 133 | Tls(ref e) => write!(w, "Tls Error: {e}"), 134 | #[cfg(feature = "__rustls")] 135 | InvalidDNSName(ref e) => write!(w, "Invalid DNS name: {e}"), 136 | InvalidMimeType(ref e) => write!(w, "Invalid mime type: {e}"), 137 | TlsDisabled => write!(w, "TLS is disabled, activate one of the tls- features"), 138 | #[cfg(all(feature = 
"__rustls", not(feature = "tls-native")))] 139 | ServerCertVerifier(ref e) => write!(w, "Invalid certificate: {e}"), 140 | } 141 | } 142 | } 143 | 144 | impl StdError for Error { 145 | fn cause(&self) -> Option<&dyn StdError> { 146 | use ErrorKind::*; 147 | 148 | match *self.0 { 149 | Io(ref e) => Some(e), 150 | Http(ref e) => Some(e), 151 | #[cfg(feature = "json")] 152 | Json(ref e) => Some(e), 153 | #[cfg(any(feature = "tls-native", feature = "__rustls"))] 154 | Tls(ref e) => Some(e), 155 | _ => None, 156 | } 157 | } 158 | } 159 | 160 | impl From for Error { 161 | fn from(_err: Infallible) -> Error { 162 | unreachable!() 163 | } 164 | } 165 | 166 | impl From for Error { 167 | fn from(err: io::Error) -> Error { 168 | Error(Box::new(ErrorKind::Io(err))) 169 | } 170 | } 171 | 172 | impl From for Error { 173 | fn from(err: http::Error) -> Error { 174 | Error(Box::new(ErrorKind::Http(err))) 175 | } 176 | } 177 | 178 | impl From for Error { 179 | fn from(err: http::header::InvalidHeaderValue) -> Error { 180 | Error(Box::new(ErrorKind::Http(http::Error::from(err)))) 181 | } 182 | } 183 | 184 | #[cfg(feature = "tls-native")] 185 | impl From for Error { 186 | fn from(err: native_tls::Error) -> Error { 187 | Error(Box::new(ErrorKind::Tls(err))) 188 | } 189 | } 190 | 191 | #[cfg(all(feature = "__rustls", not(feature = "tls-native")))] 192 | impl From for Error { 193 | fn from(err: rustls::Error) -> Error { 194 | Error(Box::new(ErrorKind::Tls(err))) 195 | } 196 | } 197 | 198 | #[cfg(feature = "json")] 199 | impl From for Error { 200 | fn from(err: serde_json::Error) -> Error { 201 | Error(Box::new(ErrorKind::Json(err))) 202 | } 203 | } 204 | 205 | #[cfg(feature = "form")] 206 | impl From for Error { 207 | fn from(err: serde_urlencoded::ser::Error) -> Error { 208 | Error(Box::new(ErrorKind::UrlEncoded(err))) 209 | } 210 | } 211 | 212 | impl From for Error { 213 | fn from(err: ErrorKind) -> Error { 214 | Error(Box::new(err)) 215 | } 216 | } 217 | 218 | impl From for Error { 
219 | fn from(kind: InvalidResponseKind) -> Error { 220 | ErrorKind::InvalidResponse(kind).into() 221 | } 222 | } 223 | 224 | #[cfg(all(feature = "__rustls", not(feature = "tls-native")))] 225 | impl From for Error { 226 | fn from(err: rustls::client::VerifierBuilderError) -> Error { 227 | Error(Box::new(ErrorKind::ServerCertVerifier(err))) 228 | } 229 | } 230 | 231 | impl From for io::Error { 232 | fn from(err: Error) -> io::Error { 233 | io::Error::new(io::ErrorKind::Other, err) 234 | } 235 | } 236 | 237 | impl From for io::Error { 238 | fn from(kind: InvalidResponseKind) -> io::Error { 239 | io::Error::new(io::ErrorKind::Other, Error(Box::new(ErrorKind::InvalidResponse(kind)))) 240 | } 241 | } 242 | 243 | /// Wrapper for the `Result` type with an `Error`. 244 | pub type Result = result::Result; 245 | -------------------------------------------------------------------------------- /src/happy.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::iter::{self, FusedIterator}; 3 | use std::net::{IpAddr, TcpStream, ToSocketAddrs}; 4 | use std::sync::mpsc::channel; 5 | use std::thread; 6 | use std::time::{Duration, Instant}; 7 | 8 | use url::Host; 9 | 10 | const RACE_DELAY: Duration = Duration::from_millis(200); 11 | 12 | /// This function implements a basic form of the happy eyeballs RFC to quickly connect 13 | /// to a domain which is available in both IPv4 and IPv6. Connection attempts are raced 14 | /// against each other and the first to connect successfully wins the race. 
15 | pub fn connect(host: &Host<&str>, port: u16, timeout: Duration, deadline: Option<Instant>) -> io::Result<TcpStream> { 16 | let addrs: Vec<_> = match *host { 17 | Host::Domain(domain) => (domain, port).to_socket_addrs()?.collect(), 18 | Host::Ipv4(ip) => return connect_within(&(IpAddr::V4(ip), port).into(), timeout, deadline), 19 | Host::Ipv6(ip) => return connect_within(&(IpAddr::V6(ip), port).into(), timeout, deadline), 20 | }; 21 | 22 | if let [addr] = &addrs[..] { 23 | debug!("DNS returned only one address, using fast path"); 24 | return connect_within(addr, timeout, deadline); 25 | } 26 | 27 | let ipv4 = addrs.iter().filter(|a| a.is_ipv4()); 28 | let ipv6 = addrs.iter().filter(|a| a.is_ipv6()); 29 | let sorted = intertwine(ipv6, ipv4); 30 | 31 | let (tx, rx) = channel(); 32 | let mut first_err = None; 33 | 34 | let start = Instant::now(); 35 | 36 | let mut handle_res = |addr, res| match res { 37 | Ok(sock) => { 38 | debug!( 39 | "successfully connected to {}, took {}ms", 40 | addr, 41 | start.elapsed().as_millis() 42 | ); 43 | 44 | Some(sock) 45 | } 46 | Err(err) => { 47 | debug!("failed to connect to {}: {}", addr, err); 48 | 49 | if first_err.is_none() { 50 | first_err = Some(err); 51 | } 52 | 53 | None 54 | } 55 | }; 56 | 57 | // This loop will race each connection attempt against others, returning early if a 58 | // connection attempt is successful.
59 | for &addr in sorted { 60 | let tx = tx.clone(); 61 | 62 | thread::spawn(move || { 63 | debug!("trying to connect to {}", addr); 64 | 65 | let res = connect_within(&addr, timeout, deadline); 66 | 67 | let _ = tx.send((addr, res)); 68 | }); 69 | 70 | if let Ok((addr, res)) = rx.recv_timeout(RACE_DELAY) { 71 | if let Some(sock) = handle_res(addr, res) { 72 | return Ok(sock); 73 | } 74 | } 75 | } 76 | 77 | // We must drop this handle to the sender in order to properly disconnect the channel 78 | // when all the threads are finished. 79 | drop(tx); 80 | 81 | // This loop waits for replies from the background threads. It will automatically timeout 82 | // when the background threads' connection attempts timeout and the senders are dropped. 83 | // This loop is reached when some of the threads do not complete within the race delay.
84 | for (addr, res) in rx.iter() { 85 | if let Some(sock) = handle_res(addr, res) { 86 | return Ok(sock); 87 | } 88 | } 89 | 90 | debug!( 91 | "could not connect to any address, took {}ms", 92 | start.elapsed().as_millis() 93 | ); 94 | 95 | Err(first_err.unwrap_or_else(|| io::Error::new(io::ErrorKind::Other, "no DNS entries found"))) 96 | } 97 | 98 | /// Connects to `addr`, giving up after `timeout` or once `deadline` (if any) is reached, whichever comes first. 99 | /// 100 | /// Returns an `io::ErrorKind::TimedOut` error without attempting to connect when no time is left before the 101 | /// deadline. This also keeps a zero duration away from `TcpStream::connect_timeout`, which rejects it as an 102 | /// invalid argument. 103 | fn connect_within(addr: &std::net::SocketAddr, timeout: Duration, deadline: Option<Instant>) -> io::Result<TcpStream> { 104 | let timeout = match deadline { 105 | None => timeout, 106 | Some(deadline) => match deadline.checked_duration_since(Instant::now()) { 107 | Some(remaining) if !remaining.is_zero() => timeout.min(remaining), 108 | _ => return Err(io::ErrorKind::TimedOut.into()), 109 | }, 110 | }; 111 | 112 | TcpStream::connect_timeout(addr, timeout) 113 | } 114 | 115 | /// Alternates items from `ita` and `itb`, starting with `ita`; once one is exhausted the rest of the other follows. 116 | fn intertwine<T, A, B>(mut ita: A, mut itb: B) -> impl Iterator<Item = T> 117 | where 118 | A: FusedIterator<Item = T>, 119 | B: FusedIterator<Item = T>, 120 | { 121 | let mut valb = None; 122 | 123 | iter::from_fn(move || { 124 | if let Some(b) = valb.take() { 125 | return Some(b); 126 | } 127 | 128 | match (ita.next(), itb.next()) { 129 | (Some(a), Some(b)) => { 130 | valb = Some(b); 131 | Some(a) 132 | } 133 | (Some(a), None) => Some(a), 134 | (None, Some(b)) => Some(b), 135 | (None, None) => None, 136 | } 137 | }) 138 | } 139 | 140 | #[test] 141 | fn test_intertwine_even() { 142 | let x: Vec<i32> = intertwine(vec![1, 2, 3].into_iter(), vec![4, 5, 6].into_iter()).collect(); 143 | assert_eq!(&x[..], &[1, 4, 2, 5, 3, 6][..]); 144 | } 145 | 146 | #[test] 147 | fn test_intertwine_left() { 148 | let x: Vec<i32> = intertwine(vec![1, 2, 3, 100, 101].into_iter(), vec![4, 5, 6].into_iter()).collect(); 149 | assert_eq!(&x[..], &[1, 4, 2, 5, 3, 6, 100, 101][..]); 150 | } 151 | 152 | #[test] 153 | fn test_intertwine_right() { 154 | let x: Vec<i32> = intertwine(vec![1, 2, 3].into_iter(), vec![4, 5, 6, 100, 101].into_iter()).collect(); 155 | assert_eq!(&x[..], &[1, 4, 2, 5, 3, 6, 100, 101][..]); 156 | } 157 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![deny(missing_debug_implementations)] 2 | #![deny(missing_docs)] 3 | #![allow(clippy::needless_doctest_main)] 4 | //! This project's goal is to provide a lightweight and simple HTTP client for the Rust ecosystem. The intended use is for 5 | //!
projects that have HTTP needs where performance is not critical or when HTTP is not the main purpose of the application. 6 | //! Note that the project still tries to perform well and avoid allocation where possible, but stays away from Rust's 7 | //! asynchronous stack to provide a crate that's as small as possible. Features are provided behind feature flags when 8 | //! possible to allow users to get just what they need. 9 | //! 10 | //! Check out the [repository](https://github.com/sbstp/attohttpc) for more information and examples. 11 | //! 12 | //! # Quick start 13 | //! ```no_run 14 | //! # #[cfg(feature = "json")] 15 | //! # use serde_json::json; 16 | //! # #[cfg(feature = "json")] 17 | //! # fn main() -> attohttpc::Result { 18 | //! let obj = json!({ 19 | //! "hello": "world", 20 | //! }); 21 | //! 22 | //! let resp = attohttpc::post("https://my-api.org/do/something") 23 | //! .header("X-My-Header", "foo") // set a header for the request 24 | //! .param("qux", "baz") // set a query parameter 25 | //! .json(&obj)? // set the request body (json feature required) 26 | //! .send()?; // send the request 27 | //! 28 | //! // Check if the status is a 2XX code. 29 | //! if resp.is_success() { 30 | //! // Consume the response body as text and print it. 31 | //! println!("{}", resp.text()?); 32 | //! } 33 | //! # Ok(()) 34 | //! # } 35 | //! # #[cfg(not(feature = "json"))] 36 | //! # fn main() { 37 | //! # } 38 | //! ``` 39 | //! 40 | //! # Features 41 | //! * `basic-auth` support for basic auth 42 | //! * `charsets` support for decoding more text encodings than just UTF-8 43 | //! * `compress` support for decompressing response bodies using `miniz_oxide` (**default**) 44 | //! * `compress-zlib` support for decompressing response bodies using `zlib` instead of `miniz_oxide` 45 | //! (see [flate2 backends](https://github.com/rust-lang/flate2-rs#backends)) 46 | //! 
* `compress-zlib-ng` support for decompressing response bodies using `zlib-ng` instead of `miniz_oxide`
47 | //! (see [flate2 backends](https://github.com/rust-lang/flate2-rs#backends))
48 | //! * `json` support for serialization and deserialization
49 | //! * `form` support for url encoded forms (does not include support for multipart)
50 | //! * `multipart-form` support for multipart forms (does not include support for url encoding)
51 | //! * `tls-native` support for TLS connections using the `native-tls` crate (**default**)
52 | //! * `tls-native-vendored` activate the `vendored` feature of `native-tls`
53 | //! * `tls-rustls-webpki-roots` support for TLS connections using `rustls` instead of `native-tls` with Web PKI roots
54 | //! * `tls-rustls-native-roots` support for TLS connections using `rustls` with root certificates loaded from the `rustls-native-certs` crate
55 | //!
56 | //! # Activating a feature
57 | //! To activate a feature, specify it in your `Cargo.toml` file like so:
58 | //! ```toml
59 | //! attohttpc = { version = "...", features = ["json", "form", ...] }
60 | //! ```
61 | //!
62 |
// Crate-internal logging shorthands: every record is forwarded to the `log` crate under
// the fixed target "attohttpc".
63 | macro_rules! debug {
64 |     ($($arg:tt)+) => { log::debug!(target: "attohttpc", $($arg)+) };
65 | }
66 |
67 | macro_rules! warn {
68 |     ($($arg:tt)+) => { log::warn!(target: "attohttpc", $($arg)+) };
69 | }
70 |
71 | #[cfg(feature = "charsets")]
72 | pub mod charsets;
73 | mod error;
74 | mod happy;
// The multipart modules are only compiled when the `multipart-form` feature is enabled.
75 | #[cfg(feature = "multipart-form")]
76 | mod multipart;
77 | #[cfg(feature = "multipart-form")]
78 | mod multipart_crate;
79 | mod parsing;
80 | mod request;
81 | mod streams;
82 | mod tls;
83 |
84 | pub use crate::error::{Error, ErrorKind, InvalidResponseKind, Result};
85 | #[cfg(feature = "multipart-form")]
86 | pub use crate::multipart::{Multipart, MultipartBuilder, MultipartFile};
87 | pub use crate::parsing::{Response, ResponseReader};
88 | pub use crate::request::proxy::{ProxySettings, ProxySettingsBuilder};
89 | pub use crate::request::{body, PreparedRequest, RequestBuilder, RequestInspector, Session};
90 | #[cfg(feature = "charsets")]
91 | pub use crate::{charsets::Charset, parsing::TextReader};
92 | pub use http::Method;
93 | pub use http::StatusCode;
94 |
95 | pub mod header {
96 |     //! This module is a re-export of the `http` crate's `header` module.
97 |     pub use http::header::*;
98 | }
99 |
// NOTE(review): the generic parameter list for `U` and the type argument of `AsRef` appear
// to have been stripped from these signatures in this dump; see upstream `src/lib.rs`.
100 | /// Create a new `RequestBuilder` with the GET method.
101 | pub fn get(base_url: U) -> RequestBuilder
102 | where
103 |     U: AsRef,
104 | {
105 |     RequestBuilder::new(Method::GET, base_url)
106 | }
107 |
108 | /// Create a new `RequestBuilder` with the POST method.
109 | pub fn post(base_url: U) -> RequestBuilder
110 | where
111 |     U: AsRef,
112 | {
113 |     RequestBuilder::new(Method::POST, base_url)
114 | }
115 |
116 | /// Create a new `RequestBuilder` with the PUT method.
117 | pub fn put(base_url: U) -> RequestBuilder
118 | where
119 |     U: AsRef,
120 | {
121 |     RequestBuilder::new(Method::PUT, base_url)
122 | }
123 |
124 | /// Create a new `RequestBuilder` with the DELETE method.
125 | pub fn delete(base_url: U) -> RequestBuilder
126 | where
127 |     U: AsRef,
128 | {
129 |     RequestBuilder::new(Method::DELETE, base_url)
130 | }
131 |
132 | /// Create a new `RequestBuilder` with the HEAD method.
133 | pub fn head(base_url: U) -> RequestBuilder
134 | where
135 |     U: AsRef,
136 | {
137 |     RequestBuilder::new(Method::HEAD, base_url)
138 | }
139 |
140 | /// Create a new `RequestBuilder` with the OPTIONS method.
141 | pub fn options(base_url: U) -> RequestBuilder
142 | where
143 |     U: AsRef,
144 | {
145 |     RequestBuilder::new(Method::OPTIONS, base_url)
146 | }
147 |
148 | /// Create a new `RequestBuilder` with the PATCH method.
149 | pub fn patch(base_url: U) -> RequestBuilder
150 | where
151 |     U: AsRef,
152 | {
153 |     RequestBuilder::new(Method::PATCH, base_url)
154 | }
155 |
156 | /// Create a new `RequestBuilder` with the TRACE method.
157 | pub fn trace(base_url: U) -> RequestBuilder
158 | where
159 |     U: AsRef,
160 | {
161 |     RequestBuilder::new(Method::TRACE, base_url)
162 | }
163 |
// Newtype whose `Debug` output is always the literal `...`, so a containing type's `Debug`
// output never prints the wrapped value.
164 | mod skip_debug {
165 |     use std::fmt;
166 |
167 |     #[derive(Clone)]
168 |     pub struct SkipDebug(pub T);
169 |
170 |     impl fmt::Debug for SkipDebug {
171 |         fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 |             write!(f, "...")
173 |         }
174 |     }
175 |
        // Allows `value.into()` wherever a `SkipDebug` is expected.
176 |     impl From for SkipDebug {
177 |         fn from(val: T) -> SkipDebug {
178 |             SkipDebug(val)
179 |         }
180 |     }
181 | }
182 |
--------------------------------------------------------------------------------
/src/multipart.rs:
--------------------------------------------------------------------------------
1 | use super::body::{Body, BodyKind};
2 | use super::{Error, ErrorKind, Result};
3 | use mime::Mime;
4 | use std::fmt;
5 | use std::io::{copy, prelude::*, Cursor, Error as IoError, Result as IoResult};
6 |
7 | /// A file to be uploaded as part of a multipart form.
8 | #[derive(Debug, Clone)]
9 | pub struct MultipartFile<'key, 'data> {
    // Name of the form field.
10 |     name: &'key str,
    // Raw file contents, borrowed for `'data`.
11 |     file: &'data [u8],
    // Optional value of the part's `filename="..."` Content-Disposition parameter.
12 |     filename: Option<&'key str>,
    // Content type of the part; when `None` it is sent as `application/octet-stream`.
13 |     mime: Option,
14 | }
15 |
16 | impl<'key, 'data> MultipartFile<'key, 'data> {
17 |     /// Constructs a new `MultipartFile` from the name and contents.
18 |     pub fn new(name: &'key str, file: &'data [u8]) -> Self {
19 |         Self {
20 |             name,
21 |             file,
22 |             filename: None,
23 |             mime: None,
24 |         }
25 |     }
26 |
27 |     /// Sets the MIME type of the file.
28 |     ///
29 |     /// # Errors
30 |     /// Returns an error if the MIME type is invalid.
        /// The parser's message is kept in `ErrorKind::InvalidMimeType`.
31 |     pub fn with_type(self, mime_type: impl AsRef) -> Result {
32 |         let mime_str = mime_type.as_ref();
33 |         let mime: Mime = match mime_str.parse() {
34 |             Ok(mime) => mime,
35 |             Err(error) => return Err(Error(Box::new(ErrorKind::InvalidMimeType(error.to_string())))),
36 |         };
37 |         Ok(Self {
38 |             mime: Some(mime),
39 |             ..self
40 |         })
41 |     }
42 |
43 |     /// Sets the filename of the file.
44 |     pub fn with_filename(self, filename: &'key str) -> Self {
45 |         Self {
46 |             filename: Some(filename),
47 |             ..self
48 |         }
49 |     }
50 | }
51 |
52 | /// A builder for creating a `Multipart` body.
53 | #[derive(Debug, Clone, Default)]
54 | pub struct MultipartBuilder<'key, 'data> {
    // (name, value) pairs for text fields.
55 |     text: Vec<(&'key str, &'data str)>,
    // File fields; they are sent as stream parts.
56 |     files: Vec>,
57 | }
58 |
59 | impl<'key, 'data> MultipartBuilder<'key, 'data> {
60 |     /// Creates a new `MultipartBuilder`.
61 |     pub fn new() -> Self {
62 |         Self::default()
63 |     }
64 |
65 |     /// Adds a text field to the form.
66 |     pub fn with_text(mut self, name: &'key str, text: &'data str) -> Self {
67 |         self.text.push((name, text));
68 |         self
69 |     }
70 |
71 |     /// Adds a `MultipartFile` to the form.
72 |     pub fn with_file(mut self, file: MultipartFile<'key, 'data>) -> Self {
73 |         self.files.push(file);
74 |         self
75 |     }
76 |
77 |     /// Creates a `Multipart` to be used as a body.
        ///
        /// # Errors
        /// Propagates an error from preparing the lazy multipart body. Only text fields and
        /// in-memory streams are added here, and the lazy encoder only fails when opening
        /// files added by path, so in practice this does not fail.
78 |     pub fn build(self) -> Result> {
79 |         let mut mp = crate::multipart_crate::lazy::Multipart::new();
80 |         for (k, v) in self.text {
81 |             mp.add_text(k, v);
82 |         }
83 |         for file in self.files {
84 |             mp.add_stream(file.name, Cursor::new(file.file), file.filename, file.mime);
85 |         }
86 |         let prepared = mp.prepare().map_err::(Into::into)?;
87 |         Ok(Multipart { data: prepared })
88 |     }
89 | }
90 |
91 | /// A multipart form created using `MultipartBuilder`.
92 | pub struct Multipart<'data> {
93 |     data: crate::multipart_crate::lazy::PreparedFields<'data>,
94 | }
95 |
96 | impl Body for Multipart<'_> {
        // The body length is never announced up front: multipart bodies are always sent chunked.
97 |     fn kind(&mut self) -> IoResult {
98 |         Ok(BodyKind::Chunked)
99 |     }
100 |
        // Streams the prepared multipart data into `writer`.
101 |     fn write(&mut self, mut writer: W) -> IoResult<()> {
102 |         copy(&mut self.data, &mut writer)?;
103 |         Ok(())
104 |     }
105 |
        // The advertised boundary must be the one embedded in the prepared body.
106 |     fn content_type(&mut self) -> IoResult> {
107 |         Ok(Some(format!("multipart/form-data; boundary={}", self.data.boundary())))
108 |     }
109 | }
110 |
111 | impl fmt::Debug for Multipart<'_> {
112 |     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 |         f.debug_struct("Multipart").finish()
114 |     }
115 | }
116 |
--------------------------------------------------------------------------------
/src/multipart_crate/lazy.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2016 `multipart` Crate Developers
2 | // Copyright 2025 Simon Bernier St-Pierre
3 | //
4 | // Licensed under the Apache License, Version 2.0, or the MIT license , at your option. This file may not be
7 | // copied, modified, or distributed except according to those terms.
8 | //! The client-side abstraction for multipart requests. Compiled only with the crate's `multipart-form` feature.
9 | //!
10 | //! Multipart requests which write out their data in one fell swoop.
11 | #![allow(dead_code)]
12 |
13 | use mime::Mime;
14 |
15 | use std::borrow::Cow;
16 | use std::error::Error;
17 | use std::fs::File;
18 | use std::path::{Path, PathBuf};
19 |
20 | use std::io::prelude::*;
21 | use std::io::Cursor;
22 | use std::{fmt, io};
23 |
24 | use super::{HttpRequest, HttpStream};
25 |
// Unwraps a `Result`, or returns early from the enclosing function with a `LazyError`.
// The two-argument form records `$field` as the field that failed; the one-argument
// form records no field.
26 | macro_rules! try_lazy (
27 |     ($field:expr, $try:expr) => (
28 |         match $try {
29 |             Ok(ok) => ok,
30 |             Err(e) => return Err(LazyError::with_field($field.into(), e)),
31 |         }
32 |     );
33 |     ($try:expr) => (
34 |         match $try {
35 |             Ok(ok) => ok,
36 |             Err(e) => return Err(LazyError::without_field(e)),
37 |         }
38 |     )
39 | );
40 |
41 | /// A `LazyError` wrapping `std::io::Error`.
42 | pub type LazyIoError<'a> = LazyError<'a, io::Error>;
43 |
44 | /// `Result` type for `LazyIoError`.
45 | pub type LazyIoResult<'a, T> = Result>;
46 |
47 | /// An error for lazily written multipart requests, including the original error as well
48 | /// as the field which caused the error, if applicable.
49 | pub struct LazyError<'a, E> {
50 |     /// The field that caused the error.
51 |     /// If `None`, there was a problem opening the stream to write or finalizing the stream.
52 |     pub field_name: Option>,
53 |     /// The inner error.
54 |     pub error: E,
55 |     /// Private field for back-compat.
56 |     _priv: (),
57 | }
58 |
59 | impl<'a, E> LazyError<'a, E> {
        /// Wrap `error` (converted with `Into`) without associating it with a field.
60 |     fn without_field>(error: E_) -> Self {
61 |         LazyError {
62 |             field_name: None,
63 |             error: error.into(),
64 |             _priv: (),
65 |         }
66 |     }
67 |
        /// Wrap `error` (converted with `Into`) and record the name of the field it occurred on.
68 |     fn with_field>(field_name: Cow<'a, str>, error: E_) -> Self {
69 |         LazyError {
70 |             field_name: Some(field_name),
71 |             error: error.into(),
72 |             _priv: (),
73 |         }
74 |     }
75 |
        /// Convert the inner error into another error type, keeping the field name.
76 |     fn transform_err>(self) -> LazyError<'a, E_> {
77 |         LazyError {
78 |             field_name: self.field_name,
79 |             error: self.error.into(),
80 |             _priv: (),
81 |         }
82 |     }
83 | }
84 |
85 | /// Take `self.error`, discarding `self.field_name`.
86 | impl<'a> From> for io::Error { 87 | fn from(val: LazyError<'a, io::Error>) -> Self { 88 | val.error 89 | } 90 | } 91 | 92 | impl Error for LazyError<'_, E> { 93 | fn cause(&self) -> Option<&dyn Error> { 94 | Some(&self.error) 95 | } 96 | } 97 | 98 | impl fmt::Debug for LazyError<'_, E> { 99 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { 100 | if let Some(ref field_name) = self.field_name { 101 | fmt.write_fmt(format_args!("LazyError (on field {:?}): {:?}", field_name, self.error)) 102 | } else { 103 | fmt.write_fmt(format_args!("LazyError (misc): {:?}", self.error)) 104 | } 105 | } 106 | } 107 | 108 | impl fmt::Display for LazyError<'_, E> { 109 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { 110 | if let Some(ref field_name) = self.field_name { 111 | fmt.write_fmt(format_args!("Error writing field {:?}: {}", field_name, self.error)) 112 | } else { 113 | fmt.write_fmt(format_args!("Error opening or flushing stream: {}", self.error)) 114 | } 115 | } 116 | } 117 | 118 | /// A multipart request which writes all fields at once upon being provided an output stream. 119 | /// 120 | /// Sacrifices static dispatch for support for dynamic construction. Reusable. 121 | /// 122 | /// #### Lifetimes 123 | /// * `'n`: Lifetime for field **n**ames; will only escape this struct in `LazyIoError<'n>`. 124 | /// * `'d`: Lifetime for **d**ata: will only escape this struct in `PreparedFields<'d>`. 125 | #[derive(Debug, Default)] 126 | pub struct Multipart<'n, 'd> { 127 | fields: Vec>, 128 | } 129 | 130 | impl<'n, 'd> Multipart<'n, 'd> { 131 | /// Initialize a new lazy dynamic request. 132 | pub fn new() -> Self { 133 | Default::default() 134 | } 135 | 136 | /// Add a text field to this request. 
137 | pub fn add_text(&mut self, name: N, text: T) -> &mut Self 138 | where 139 | N: Into>, 140 | T: Into>, 141 | { 142 | self.fields.push(Field { 143 | name: name.into(), 144 | data: Data::Text(text.into()), 145 | }); 146 | 147 | self 148 | } 149 | 150 | /// Add a file field to this request. 151 | /// 152 | /// ### Note 153 | /// Does not check if `path` exists. 154 | pub fn add_file(&mut self, name: N, path: P) -> &mut Self 155 | where 156 | N: Into>, 157 | P: IntoCowPath<'d>, 158 | { 159 | self.fields.push(Field { 160 | name: name.into(), 161 | data: Data::File(path.into_cow_path()), 162 | }); 163 | 164 | self 165 | } 166 | 167 | /// Add a generic stream field to this request, 168 | pub fn add_stream(&mut self, name: N, stream: R, filename: Option, mime: Option) -> &mut Self 169 | where 170 | N: Into>, 171 | R: Read + 'd, 172 | F: Into>, 173 | { 174 | self.fields.push(Field { 175 | name: name.into(), 176 | data: Data::Stream(Stream { 177 | content_type: mime.unwrap_or(mime::APPLICATION_OCTET_STREAM), 178 | filename: filename.map(|f| f.into()), 179 | stream: Box::new(stream), 180 | }), 181 | }); 182 | 183 | self 184 | } 185 | 186 | /// Convert `req` to `HttpStream`, write out the fields in this request, and finish the 187 | /// request, returning the response if successful, or the first error encountered. 188 | /// 189 | /// If any files were added by path they will now be opened for reading. 190 | pub fn send( 191 | &mut self, 192 | mut req: R, 193 | ) -> Result<::Response, LazyError<'n, ::Error>> { 194 | let mut prepared = self.prepare().map_err(LazyError::transform_err)?; 195 | 196 | req.apply_headers(prepared.boundary(), prepared.content_len()); 197 | 198 | let mut stream = try_lazy!(req.open_stream()); 199 | 200 | try_lazy!(io::copy(&mut prepared, &mut stream)); 201 | 202 | stream.finish().map_err(LazyError::without_field) 203 | } 204 | 205 | /// Export the multipart data contained in this lazy request as an adaptor which implements `Read`. 
206 | /// 207 | /// During this step, if any files were added by path then they will be opened for reading 208 | /// and their length measured. 209 | pub fn prepare(&mut self) -> LazyIoResult<'n, PreparedFields<'d>> { 210 | PreparedFields::from_fields(&mut self.fields) 211 | } 212 | } 213 | 214 | #[derive(Debug)] 215 | struct Field<'n, 'd> { 216 | name: Cow<'n, str>, 217 | data: Data<'n, 'd>, 218 | } 219 | 220 | enum Data<'n, 'd> { 221 | Text(Cow<'d, str>), 222 | File(Cow<'d, Path>), 223 | Stream(Stream<'n, 'd>), 224 | } 225 | 226 | impl fmt::Debug for Data<'_, '_> { 227 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 228 | match *self { 229 | Data::Text(ref text) => write!(f, "Data::Text({:?})", text), 230 | Data::File(ref path) => write!(f, "Data::File({:?})", path), 231 | Data::Stream(_) => f.write_str("Data::Stream(Box)"), 232 | } 233 | } 234 | } 235 | 236 | struct Stream<'n, 'd> { 237 | filename: Option>, 238 | content_type: Mime, 239 | stream: Box, 240 | } 241 | 242 | /// The result of [`Multipart::prepare()`](struct.Multipart.html#method.prepare). 243 | /// 244 | /// Implements `Read`, contains the entire request body. 245 | /// 246 | /// Individual files/streams are dropped as they are read to completion. 247 | /// 248 | /// ### Note 249 | /// The fields in the request may have been reordered to simplify the preparation step. 250 | /// No compliant server implementation will be relying on the specific ordering of fields anyways. 
251 | pub struct PreparedFields<'d> { 252 | text_data: Cursor>, 253 | streams: Vec>, 254 | end_boundary: Cursor, 255 | content_len: Option, 256 | } 257 | 258 | impl<'d> PreparedFields<'d> { 259 | fn from_fields<'n>(fields: &mut Vec>) -> Result> { 260 | debug!("Field count: {}", fields.len()); 261 | 262 | // One of the two RFCs specifies that any bytes before the first boundary are to be 263 | // ignored anyway 264 | let mut boundary = format!("\r\n--{}", super::gen_boundary()); 265 | 266 | let mut text_data = Vec::new(); 267 | let mut streams = Vec::new(); 268 | let mut content_len = 0u64; 269 | let mut use_len = true; 270 | 271 | for field in fields.drain(..) { 272 | match field.data { 273 | Data::Text(text) => write!( 274 | text_data, 275 | "{}\r\nContent-Disposition: form-data; \ 276 | name=\"{}\"\r\n\r\n{}", 277 | boundary, field.name, text 278 | ) 279 | .unwrap(), 280 | Data::File(file) => { 281 | let (stream, len) = PreparedField::from_path(field.name, &file, &boundary)?; 282 | content_len += len; 283 | streams.push(stream); 284 | } 285 | Data::Stream(stream) => { 286 | use_len = false; 287 | 288 | streams.push(PreparedField::from_stream( 289 | &field.name, 290 | &boundary, 291 | &stream.content_type, 292 | stream.filename.as_deref(), 293 | stream.stream, 294 | )); 295 | } 296 | } 297 | } 298 | 299 | // So we don't write a spurious end boundary 300 | if text_data.is_empty() && streams.is_empty() { 301 | boundary = String::new(); 302 | } else { 303 | boundary.push_str("--"); 304 | } 305 | 306 | content_len += boundary.len() as u64; 307 | 308 | Ok(PreparedFields { 309 | text_data: Cursor::new(text_data), 310 | streams, 311 | end_boundary: Cursor::new(boundary), 312 | content_len: if use_len { Some(content_len) } else { None }, 313 | }) 314 | } 315 | 316 | /// Get the content-length value for this set of fields, if applicable (all fields are sized, 317 | /// i.e. not generic streams). 
318 | pub fn content_len(&self) -> Option { 319 | self.content_len 320 | } 321 | 322 | /// Get the boundary that was used to serialize the request. 323 | pub fn boundary(&self) -> &str { 324 | let boundary = self.end_boundary.get_ref(); 325 | 326 | // Get just the bare boundary string 327 | &boundary[4..boundary.len() - 2] 328 | } 329 | } 330 | 331 | impl Read for PreparedFields<'_> { 332 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 333 | if buf.is_empty() { 334 | debug!("PreparedFields::read() was passed a zero-sized buffer."); 335 | return Ok(0); 336 | } 337 | 338 | let mut total_read = 0; 339 | 340 | while total_read < buf.len() && !cursor_at_end(&self.end_boundary) { 341 | let buf = &mut buf[total_read..]; 342 | 343 | total_read += if !cursor_at_end(&self.text_data) { 344 | self.text_data.read(buf)? 345 | } else if let Some(mut field) = self.streams.pop() { 346 | match field.read(buf) { 347 | Ok(0) => continue, 348 | res => { 349 | self.streams.push(field); 350 | res 351 | } 352 | }? 353 | } else { 354 | self.end_boundary.read(buf)? 
355 | }; 356 | } 357 | 358 | Ok(total_read) 359 | } 360 | } 361 | 362 | struct PreparedField<'d> { 363 | header: Cursor>, 364 | stream: Box, 365 | } 366 | 367 | impl<'d> PreparedField<'d> { 368 | fn from_path<'n>(name: Cow<'n, str>, path: &Path, boundary: &str) -> Result<(Self, u64), LazyIoError<'n>> { 369 | let (content_type, filename) = super::mime_filename(path); 370 | 371 | let file = try_lazy!(name, File::open(path)); 372 | let content_len = try_lazy!(name, file.metadata()).len(); 373 | 374 | let stream = Self::from_stream(&name, boundary, &content_type, filename, Box::new(file)); 375 | 376 | let content_len = content_len + (stream.header.get_ref().len() as u64); 377 | 378 | Ok((stream, content_len)) 379 | } 380 | 381 | fn from_stream( 382 | name: &str, 383 | boundary: &str, 384 | content_type: &Mime, 385 | filename: Option<&str>, 386 | stream: Box, 387 | ) -> Self { 388 | let mut header = Vec::new(); 389 | 390 | write!( 391 | header, 392 | "{}\r\nContent-Disposition: form-data; name=\"{}\"", 393 | boundary, name 394 | ) 395 | .unwrap(); 396 | 397 | if let Some(filename) = filename { 398 | write!(header, "; filename=\"{}\"", filename).unwrap(); 399 | } 400 | 401 | write!(header, "\r\nContent-Type: {}\r\n\r\n", content_type).unwrap(); 402 | 403 | PreparedField { 404 | header: Cursor::new(header), 405 | stream, 406 | } 407 | } 408 | } 409 | 410 | impl Read for PreparedField<'_> { 411 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 412 | debug!("PreparedField::read()"); 413 | 414 | if !cursor_at_end(&self.header) { 415 | self.header.read(buf) 416 | } else { 417 | self.stream.read(buf) 418 | } 419 | } 420 | } 421 | 422 | impl fmt::Debug for PreparedField<'_> { 423 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 424 | f.debug_struct("PreparedField") 425 | .field("header", &self.header) 426 | .field("stream", &"Box") 427 | .finish() 428 | } 429 | } 430 | 431 | /// Conversion trait necessary for `Multipart::add_file()` to accept borrowed or owned 
strings
432 | /// and borrowed or owned paths
433 | pub trait IntoCowPath<'a> {
434 |     /// Convert `self` into a path that is either borrowed (`&Path`, `&str`) or owned (`PathBuf`, `String`).
435 |     fn into_cow_path(self) -> Cow<'a, Path>;
436 | }
437 |
438 | impl<'a> IntoCowPath<'a> for Cow<'a, Path> {
439 |     fn into_cow_path(self) -> Cow<'a, Path> {
440 |         self
441 |     }
442 | }
443 |
444 | impl IntoCowPath<'static> for PathBuf {
445 |     fn into_cow_path(self) -> Cow<'static, Path> {
446 |         self.into()
447 |     }
448 | }
449 |
450 | impl<'a> IntoCowPath<'a> for &'a Path {
451 |     fn into_cow_path(self) -> Cow<'a, Path> {
452 |         self.into()
453 |     }
454 | }
455 |
456 | impl IntoCowPath<'static> for String {
457 |     fn into_cow_path(self) -> Cow<'static, Path> {
458 |         PathBuf::from(self).into()
459 |     }
460 | }
461 |
462 | impl<'a> IntoCowPath<'a> for &'a str {
463 |     fn into_cow_path(self) -> Cow<'a, Path> {
464 |         Path::new(self).into()
465 |     }
466 | }
467 |
// True when the cursor's position has reached the end of its underlying byte buffer.
468 | fn cursor_at_end>(cursor: &Cursor) -> bool {
469 |     cursor.position() == (cursor.get_ref().as_ref().len() as u64)
470 | }
471 |
--------------------------------------------------------------------------------
/src/multipart_crate/mod.rs:
--------------------------------------------------------------------------------
1 | // Copyright 2016 `multipart` Crate Developers
2 | // Copyright 2025 github.com/sbstp
3 | //
4 | // Licensed under the Apache License, Version 2.0, or the MIT license , at your option. This file may not be
7 | // copied, modified, or distributed except according to those terms.
8 | //! The client-side abstraction for multipart requests. Compiled only with the crate's `multipart-form` feature.
9 | //!
10 | //! Use this when sending POST requests with files to a server.
11 | #![allow(dead_code)]
12 | use mime::Mime;
13 | use rand::Rng;
14 |
15 | use std::borrow::Cow;
16 | use std::fs::File;
17 | use std::io;
18 | use std::io::prelude::*;
19 |
20 | use std::path::Path;
21 |
22 | pub mod lazy;
23 |
// Number of random alphanumeric characters in a generated boundary.
24 | const BOUNDARY_LEN: usize = 16;
25 |
// Evaluates `$try`. On success yields `Ok($selff)` so the call can be chained; on failure
// converts the error with `Into`.
26 | macro_rules! map_self {
27 |     ($selff:expr, $try:expr) => {
28 |         match $try {
29 |             Ok(_) => Ok($selff),
30 |             Err(err) => Err(err.into()),
31 |         }
32 |     };
33 | }
34 |
// Chains `Result`-returning expressions with `and_then`, stopping at the first error. The
// first form discards each previous `Ok` value; the second binds it to the given identifier(s).
35 | macro_rules! chain_result {
36 |     ($first_expr:expr, $($try_expr:expr),*) => (
37 |         $first_expr $(.and_then(|_| $try_expr))*
38 |     );
39 |     ($first_expr:expr, $($($arg:ident),+ -> $try_expr:expr),*) => (
40 |         $first_expr $(.and_then(|$($arg),+| $try_expr))*
41 |     );
42 | }
43 |
44 | /// The entry point of the client-side multipart API.
45 | ///
46 | /// The `.write_*()` methods return `Result<&mut Self, S::Error>`, so calls can be chained with `?`;
47 | /// an I/O error is reported immediately by the call that hit it. `.send()` writes the closing
48 | /// boundary and finishes the request.
49 | pub struct Multipart {
50 |     writer: MultipartWriter<'static, S>,
51 | }
52 |
53 | impl Multipart<()> {
54 |     /// Create a new `Multipart` to wrap a request.
55 |     ///
56 |     /// ## Returns Error
57 |     /// If `req.open_stream()` returns an error.
58 |     pub fn from_request(req: R) -> Result, R::Error> {
59 |         let (boundary, stream) = open_stream(req, None)?;
60 |
61 |         Ok(Multipart {
62 |             writer: MultipartWriter::new(stream, boundary),
63 |         })
64 |     }
65 | }
66 |
67 | impl Multipart {
68 |     /// Write a text field to this multipart request.
69 |     /// `name` and `val` can be either owned `String` or `&str`.
70 |     ///
71 |     /// ## Errors
72 |     /// If something went wrong with the HTTP stream.
73 |     pub fn write_text, V: AsRef>(&mut self, name: N, val: V) -> Result<&mut Self, S::Error> {
74 |         map_self!(self, self.writer.write_text(name.as_ref(), val.as_ref()))
75 |     }
76 |
77 |     /// Open a file pointed to by `path` and write its contents to the multipart request,
78 |     /// supplying its filename and guessing its `Content-Type` from its extension.
79 |     ///
80 |     /// If you want to set these values manually, or use another type that implements `Read`,
81 |     /// use `.write_stream()`.
82 |     ///
83 |     /// `name` can be either `String` or `&str`, and `path` can be `PathBuf` or `&Path`.
84 |     ///
85 |     /// ## Errors
86 |     /// If there was a problem opening the file (was a directory or didn't exist),
87 |     /// or if something went wrong with the HTTP stream.
88 |     pub fn write_file, P: AsRef>(&mut self, name: N, path: P) -> Result<&mut Self, S::Error> {
89 |         let name = name.as_ref();
90 |         let path = path.as_ref();
91 |
92 |         map_self!(self, self.writer.write_file(name, path))
93 |     }
94 |
95 |     /// Write a byte stream to the multipart request as a file field, supplying `filename` if given,
96 |     /// and `content_type` if given or `"application/octet-stream"` if not.
97 |     ///
98 |     /// `name` can be either `String` or `&str`; `stream` is borrowed mutably for the
99 |     /// duration of the call.
100 |     ///
101 |     /// ## Warning
102 |     /// The given `Read` **must** be able to read to EOF (end of file/no more data), meaning
103 |     /// `Read::read()` returns `Ok(0)`. If it never returns EOF it will be read to infinity
104 |     /// and the request will never be completed.
105 |     ///
106 |     /// `SizedRequest` support is commented out in this copy (see `from_request_sized` below); if
107 |     /// restored, it buffers the whole multipart body in memory to compute its size.
108 |     ///
109 |     /// Use `Read::take()` if you wish to send data from a `Read`
110 |     /// that will never return EOF otherwise.
111 |     ///
112 |     /// ## Errors
113 |     /// If the reader returned an error, or if something went wrong with the HTTP stream.
114 |
115 |     pub fn write_stream, St: Read>(
116 |         &mut self,
117 |         name: N,
118 |         stream: &mut St,
119 |         filename: Option<&str>,
120 |         content_type: Option,
121 |     ) -> Result<&mut Self, S::Error> {
122 |         let name = name.as_ref();
123 |
124 |         map_self!(self, self.writer.write_stream(stream, name, filename, content_type))
125 |     }
126 |
127 |     /// Write the closing boundary, finish the underlying HTTP stream and return its response,
        /// or the first error encountered.
128 |     pub fn send(self) -> Result {
129 |         self.writer
130 |             .finish()
131 |             .map_err(io::Error::into)
132 |             .and_then(|body| body.finish())
133 |     }
134 | }
135 |
136 | // impl Multipart>
137 | // where
138 | //     ::Error: From,
139 | // {
140 | //     /// Create a new `Multipart` using the `SizedRequest` wrapper around `req`.
141 | //     pub fn from_request_sized(req: R) -> Result {
142 | //         Multipart::from_request(SizedRequest::from_request(req))
143 | //     }
144 | // }
145 |
146 | /// A trait describing an HTTP request that can be used to send multipart data.
147 | pub trait HttpRequest {
148 |     /// The HTTP stream type that can be opened by this request, to which the multipart data will be
149 |     /// written.
150 |     type Stream: HttpStream;
151 |     /// The error type for this request.
152 |     /// Must be convertible from `io::Error` and into the stream's error type (`<Self::Stream as HttpStream>::Error`).
153 |     type Error: From + Into<::Error>;
154 |
155 |     /// Set the `Content-Type` header to `multipart/form-data` and supply the `boundary` value.
156 |     /// If `content_len` is given, set the `Content-Length` header to its value.
157 |     ///
158 |     /// Return `true` if any and all sanity checks passed and the stream is ready to be opened,
159 |     /// or `false` otherwise.
160 |     fn apply_headers(&mut self, boundary: &str, content_len: Option) -> bool;
161 |
162 |     /// Open the request stream and return it or any error otherwise.
163 |     fn open_stream(self) -> Result;
164 | }
165 |
166 | /// A trait describing an open HTTP stream that can be written to.
167 | pub trait HttpStream: Write {
168 |     /// The request type that opened this stream.
169 |     type Request: HttpRequest;
170 |     /// The response type that will be returned after the request is completed.
171 |     type Response;
172 |     /// The error type for this stream.
173 |     /// Must be compatible with `io::Error` as well as `Self::Request::Error`.
174 |     type Error: From + From<::Error>;
175 |
176 |     /// Finalize and close the stream and return the response object, or any error otherwise.
177 |     fn finish(self) -> Result;
178 | }
179 |
// No-op request: its stream is `io::Sink`, so everything written is discarded.
180 | impl HttpRequest for () {
181 |     type Stream = io::Sink;
182 |     type Error = io::Error;
183 |
184 |     fn apply_headers(&mut self, _: &str, _: Option) -> bool {
185 |         true
186 |     }
187 |     fn open_stream(self) -> Result {
188 |         Ok(io::sink())
189 |     }
190 | }
191 |
192 | impl HttpStream for io::Sink {
193 |     type Request = ();
194 |     type Response = ();
195 |     type Error = io::Error;
196 |
197 |     fn finish(self) -> Result {
198 |         Ok(())
199 |     }
200 | }
201 |
/// Generate a random boundary made of `BOUNDARY_LEN` ASCII alphanumeric characters.
202 | fn gen_boundary() -> String {
203 |     rand::rng()
204 |         .sample_iter(rand::distr::Alphanumeric)
205 |         .take(BOUNDARY_LEN)
206 |         .map(|c| c as char)
207 |         .collect()
208 | }
209 |
/// Generate a boundary, apply the multipart headers to `req`, and open its stream.
210 | fn open_stream(mut req: R, content_len: Option) -> Result<(String, R::Stream), R::Error> {
211 |     let boundary = gen_boundary();
        // NOTE(review): the `bool` sanity-check result of `apply_headers` is ignored here.
212 |     req.apply_headers(&boundary, content_len);
213 |     req.open_stream().map(|stream| (boundary, stream))
214 | }
215 |
// Streaming multipart encoder that writes each part directly to `inner`.
216 | struct MultipartWriter<'a, W> {
217 |     inner: W,
218 |     boundary: Cow<'a, str>,
        // Set once the first part header has been written; later boundaries (and `finish`)
        // then emit the CRLF that terminates the previous part's content.
219 |     data_written: bool,
220 | }
221 |
222 | impl<'a, W: Write> MultipartWriter<'a, W> {
223 |     fn new>>(inner: W, boundary: B) -> Self {
224 |         MultipartWriter {
225 |             inner,
226 |             boundary: boundary.into(),
227 |             data_written: false,
228 |         }
229 |     }
230 |
        // Writes `--boundary\r\n`, preceded by CRLF when a previous part exists.
231 |     fn write_boundary(&mut self) -> io::Result<()> {
232 |         if self.data_written {
233 |             self.inner.write_all(b"\r\n")?;
234 |         }
235 |
236 |         write!(self.inner, "--{}\r\n", self.boundary)
237 |     }
238 |
239 |     fn write_text(&mut self, name: &str, text: &str) -> io::Result<()> {
240 |         chain_result! {
241 |             self.write_field_headers(name, None, None),
242 |             self.inner.write_all(text.as_bytes())
243 |         }
244 |     }
245 |
        // Content type and file name are derived from `path` via `mime_filename`.
246 |     fn write_file(&mut self, name: &str, path: &Path) -> io::Result<()> {
247 |         let (content_type, filename) = mime_filename(path);
248 |         let mut file = File::open(path)?;
249 |         self.write_stream(&mut file, name, filename, Some(content_type))
250 |     }
251 |
252 |     fn write_stream(
253 |         &mut self,
254 |         stream: &mut S,
255 |         name: &str,
256 |         filename: Option<&str>,
257 |         content_type: Option,
258 |     ) -> io::Result<()> {
259 |         // This is necessary to make sure it is interpreted as a file on the server end.
260 |         let content_type = Some(content_type.unwrap_or(mime::APPLICATION_OCTET_STREAM));
261 |
262 |         chain_result! {
263 |             self.write_field_headers(name, filename, content_type),
264 |             io::copy(stream, &mut self.inner),
265 |             Ok(())
266 |         }
267 |     }
268 |
        // NOTE(review): `name` and `filename` are interpolated without escaping.
269 |     fn write_field_headers(
270 |         &mut self,
271 |         name: &str,
272 |         filename: Option<&str>,
273 |         content_type: Option,
274 |     ) -> io::Result<()> {
275 |         chain_result! {
276 |             // Write the first boundary, or the boundary for the previous field.
277 |             self.write_boundary(),
278 |             { self.data_written = true; Ok(()) },
279 |             write!(self.inner, "Content-Disposition: form-data; name=\"{}\"", name),
280 |             filename.map(|filename| write!(self.inner, "; filename=\"{}\"", filename))
281 |                 .unwrap_or(Ok(())),
282 |             content_type.map(|content_type| write!(self.inner, "\r\nContent-Type: {}", content_type))
283 |                 .unwrap_or(Ok(())),
284 |             self.inner.write_all(b"\r\n\r\n")
285 |         }
286 |     }
287 |
288 |     fn finish(mut self) -> io::Result {
289 |         if self.data_written {
290 |             self.inner.write_all(b"\r\n")?;
291 |         }
292 |
293 |         // always write the closing boundary, even for empty bodies
294 |         // trailing CRLF is optional but Actix requires it due to a naive implementation:
295 |         // https://github.com/actix/actix-web/issues/598
296 |         write!(self.inner, "--{}--\r\n", self.boundary)?;
297 |         Ok(self.inner)
298 |     }
299 | }
300 |
/// Guess the content type from `path` (falling back to `application/octet-stream`) and
/// extract its UTF-8 file name, if any.
301 | fn mime_filename(path: &Path) -> (Mime, Option<&str>) {
302 |     let content_type = mime_guess::from_path(path);
303 |     let filename = opt_filename(path);
304 |     (content_type.first_or_octet_stream(), filename)
305 | }
306 |
307 | fn opt_filename(path: &Path) -> Option<&str> {
308 |     path.file_name().and_then(|filename| filename.to_str())
309 | }
310 |
--------------------------------------------------------------------------------
/src/parsing/body_reader.rs:
--------------------------------------------------------------------------------
1 | use std::io::{self, BufRead, BufReader, Read, Take};
2 |
3 | use http::header::{HeaderMap, HeaderValue, CONTENT_LENGTH, TRANSFER_ENCODING};
4 |
5 | use crate::error::{InvalidResponseKind, Result};
6 | use crate::parsing::chunked_reader::ChunkedReader;
7 | use crate::streams::BaseStream;
8 |
/// Response body reader; the framing variant is chosen by `BodyReader::new` from the headers.
9 | #[derive(Debug)]
10 | pub enum BodyReader {
        /// `Transfer-Encoding: chunked` body.
11 |     Chunked(ChunkedReader),
        /// Body limited to the `Content-Length` value.
12 |     Length(Take>),
        /// Body delimited by the server closing the connection.
13 |     Close(BufReader),
14 | }
15 |
16 | impl Read for BodyReader {
17 |     #[inline]
18 |     fn read(&mut self, buf: &mut [u8]) -> io::Result {
19 |         match self {
20 |             BodyReader::Chunked(r) => r.read(buf),
21 |             BodyReader::Length(r) => r.read(buf),
22 |             BodyReader::Close(r) => r.read(buf),
23 |         }
24 |     }
25 | }
26 |
27 | impl BufRead for BodyReader {
28 |     #[inline]
29 |     fn fill_buf(&mut self) -> io::Result<&[u8]> {
30 |         match self {
31 |             BodyReader::Chunked(r) => r.fill_buf(),
32 |             BodyReader::Length(r) => r.fill_buf(),
33 |             BodyReader::Close(r) => r.fill_buf(),
34 |         }
35 |     }
36 |
37 |     #[inline]
38 |     fn consume(&mut self, amt: usize) {
39 |         match self {
40 |             BodyReader::Chunked(r) => r.consume(amt),
41 |             BodyReader::Length(r) => r.consume(amt),
42 |             BodyReader::Close(r) => r.consume(amt),
43 |         }
44 |     }
45 | }
46 |
/// True if any comma-separated coding in any `Transfer-Encoding` value is `chunked`
/// (case-insensitive). Values that are not valid visible ASCII are skipped.
// NOTE(review): `chunked` is accepted in any position, not only as the final coding; per
// RFC 9112 §6.3 a response whose final coding is not `chunked` is delimited by connection close.
47 | fn is_chunked(headers: &HeaderMap) -> bool {
48 |     headers
49 |         .get_all(TRANSFER_ENCODING)
50 |         .into_iter()
51 |         .filter_map(|val| val.to_str().ok())
52 |         .any(|val| {
53 |             val.split(',')
54 |                 .map(|s| s.trim())
55 |                 .any(|s| s.eq_ignore_ascii_case("chunked"))
56 |         })
57 | }
58 |
/// Parse one `Content-Length` value; any failure maps to `InvalidResponseKind::ContentLength`.
59 | fn parse_content_length(val: &HeaderValue) -> Result {
60 |     let val = val.to_str().map_err(|_| InvalidResponseKind::ContentLength)?;
61 |     let val = val.parse::().map_err(|_| InvalidResponseKind::ContentLength)?;
62 |     Ok(val)
63 | }
64 |
/// Return the `Content-Length` value, if present. Repeated headers are accepted only when
/// they all carry the same value; otherwise `InvalidResponseKind::ContentLength` is returned.
65 | fn is_content_length(headers: &HeaderMap) -> Result> {
66 |     let mut last = None;
67 |     for val in headers.get_all(CONTENT_LENGTH) {
68 |         let val = parse_content_length(val)?;
69 |         last = Some(match last {
70 |             None => val,
71 |             Some(last) if last == val => val,
72 |             _ => {
73 |                 return Err(InvalidResponseKind::ContentLength.into());
74 |             }
75 |         });
76 |     }
77 |     Ok(last)
78 | }
79 |
80 | impl BodyReader {
        /// Pick the body framing: chunked if `Transfer-Encoding` mentions `chunked`, else a
        /// length-limited reader if `Content-Length` is present, else read until close.
        ///
        /// # Errors
        /// Returns an error if `Content-Length` is invalid or its repeated values disagree.
81 |     pub fn new(headers: &HeaderMap, reader: BufReader) -> Result {
82 |         if is_chunked(headers) {
83 |             debug!("creating a chunked body reader");
84 |             Ok(BodyReader::Chunked(ChunkedReader::new(reader)))
85 |         } else if let Some(val) = is_content_length(headers)? {
86 |             debug!("creating a length body reader");
87 |             Ok(BodyReader::Length(reader.take(val)))
88 |         } else {
89 |             debug!("creating close reader");
90 |             Ok(BodyReader::Close(reader))
91 |         }
92 |     }
93 | }
94 |
95 | #[test]
96 | fn test_is_chunked_false() {
97 |     let mut headers = HeaderMap::new();
98 |     headers.insert("content-encoding", HeaderValue::from_static("gzip"));
99 |     assert!(!is_chunked(&headers));
100 | }
101 |
102 | #[test]
103 | fn test_is_chunked_simple() {
104 |     let mut headers = HeaderMap::new();
105 |     headers.insert("transfer-encoding", HeaderValue::from_static("chunked"));
106 |     assert!(is_chunked(&headers));
107 | }
108 |
109 | #[test]
110 | fn test_is_chunked_multi() {
111 |     let mut headers = HeaderMap::new();
112 |     headers.insert("transfer-encoding", HeaderValue::from_static("gzip, chunked"));
113 |     assert!(is_chunked(&headers));
114 | }
115 |
116 | #[test]
117 | fn test_parse_content_length_ok() {
118 |     assert_eq!(parse_content_length(&HeaderValue::from_static("17")).ok(), Some(17));
119 | }
120 |
121 | #[test]
122 | fn test_parse_content_length_err() {
123 |     assert!(parse_content_length(&HeaderValue::from_static("XD")).is_err());
124 | }
125 |
126 | #[test]
127 | fn test_is_content_length_none() {
128 |     let headers = HeaderMap::new();
129 |     assert_eq!(is_content_length(&headers).ok(), Some(None));
130 | }
131 |
132 | #[test]
133 | fn test_is_content_length_one() {
134 |     let mut headers = HeaderMap::new();
135 |     headers.insert("content-length", HeaderValue::from_static("88"));
136 |     assert_eq!(is_content_length(&headers).ok(), Some(Some(88)));
137 | }
138 |
139 | #[test]
140 | fn test_is_content_length_many_ok() {
141 |     let mut headers = HeaderMap::new();
142 |     headers.append("content-length", HeaderValue::from_static("88"));
143 |     headers.append("content-length", HeaderValue::from_static("88"));
144 |
145 |     assert_eq!(headers.get_all("content-length").iter().count(), 2);
146 |     assert_eq!(is_content_length(&headers).ok(), Some(Some(88)));
147 | }
#[test]
fn test_is_content_length_many_err() {
    let mut headers = HeaderMap::new();
    headers.append("content-length", HeaderValue::from_static("88"));
    headers.append("content-length", HeaderValue::from_static("90"));

    assert_eq!(headers.get_all("content-length").iter().count(), 2);
    assert!(is_content_length(&headers).is_err());
}
--------------------------------------------------------------------------------
/src/parsing/buffers.rs:
--------------------------------------------------------------------------------
use std::io::{self, BufRead, BufReader, Read, Write};

/// Reads one line into `buf`, accepting `\n` or `\r\n` as the terminator and stripping it.
///
/// `buf` is cleared first and at most `max_buf_len` bytes are taken from `reader`.
/// Returns the number of bytes consumed, terminator included. If no `\n` is found before
/// the limit or end of input, fails with `io::ErrorKind::UnexpectedEof`; `buf` then keeps
/// the partial data that was read.
pub fn read_line(reader: &mut BufReader, buf: &mut Vec, max_buf_len: u64) -> io::Result
where
    R: Read,
{
    buf.clear();
    let n = reader.take(max_buf_len).read_until(b'\n', buf)?;

    if buf.ends_with(b"\r\n") {
        buf.truncate(buf.len() - 2);
    } else if buf.ends_with(b"\n") {
        buf.truncate(buf.len() - 1);
    } else {
        // Limit reached or input ended before a newline.
        return Err(io::ErrorKind::UnexpectedEof.into());
    }

    Ok(n)
}

/// Reads one line that must end with `\r\n`, stripping that terminator.
///
/// A bare `\n` does not end the line: it is kept in `buf` and reading continues until a
/// `\n` directly preceded by `\r` is seen. `max_buf_len` bounds the whole line across
/// those reads. Returns the number of bytes consumed (CRLF included); fails with
/// `io::ErrorKind::UnexpectedEof` if the limit or end of input is hit first.
pub fn read_line_strict(reader: &mut BufReader, buf: &mut Vec, max_buf_len: u64) -> io::Result
where
    R: Read,
{
    buf.clear();
    // A single `Take` shared by every iteration, so the limit covers the complete line.
    let mut reader = reader.take(max_buf_len);
    let mut n = 0;

    loop {
        let k = reader.read_until(b'\n', buf)?;
        n += k;

        // Nothing read, or stopped without a newline: limit reached or end of input.
        if k == 0 || buf[buf.len() - 1] != b'\n' {
            return Err(io::ErrorKind::UnexpectedEof.into());
        }

        // `k >= 2` guarantees `buf.len() >= 2` for the index below and that the `\r`
        // was read in this segment.
        if k >= 2 && buf[buf.len() - 2] == b'\r' && buf[buf.len() - 1] == b'\n' {
            buf.truncate(buf.len() - 2);
            return Ok(n);
        }
    }
}

/// Consumes one line ending from `reader` and reports whether it was well-formed.
///
/// Reads one byte, plus a second one if the first is `\r`, and returns `Ok(true)` when
/// the last byte read is `\n` (so `\n` and `\r\n` are accepted). The bytes are consumed
/// even when `Ok(false)` is returned; end of input surfaces as `read_exact`'s error.
pub fn read_line_ending(reader: &mut BufReader) -> io::Result
where
    R: Read,
{
    let mut b = [0];
    reader.read_exact(&mut b)?;

    if &b == b"\r" {
        reader.read_exact(&mut b)?;
    }

    Ok(&b == b"\n")
}

/// Strips every leading and trailing occurrence of `byte` from `buf`.
pub fn trim_byte(byte: u8, buf: &[u8]) -> &[u8] {
    trim_byte_left(byte, trim_byte_right(byte, buf))
}

/// Strips leading occurrences of `byte`; returns an empty slice if `buf` is all `byte`.
pub fn trim_byte_left(byte: u8, buf: &[u8]) -> &[u8] {
    buf.iter().position(|b| *b != byte).map_or(&[], |n| &buf[n..])
}

/// Strips trailing occurrences of `byte`; returns an empty slice if `buf` is all `byte`.
pub fn trim_byte_right(byte: u8, buf: &[u8]) -> &[u8] {
    buf.iter().rposition(|b| *b != byte).map_or(&[], |n| &buf[..=n])
}

/// Replaces, in place, every occurrence of `byte` in `buf` with `replace`.
pub fn replace_byte(byte: u8, replace: u8, buf: &mut [u8]) {
    for val in buf {
        if *val == byte {
            *val = replace;
        }
    }
}

/// Stream wrapper whose reads go through a `BufReader`, while writes and flushes are
/// forwarded directly to the wrapped stream (bypassing the read buffer).
#[derive(Debug)]
pub struct BufReaderWrite {
    inner: BufReader,
}

impl BufReaderWrite {
    /// Wraps `inner` in a `BufReader` created with `BufReader::new`.
    pub fn new(inner: R) -> BufReaderWrite {
        BufReaderWrite {
            inner: BufReader::new(inner),
        }
    }
}

impl Read for BufReaderWrite {
    /// Buffered read through the inner `BufReader`.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result {
        self.inner.read(buf)
    }
}

impl Write for BufReaderWrite {
    /// Writes straight to the wrapped stream.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result {
        self.inner.get_mut().write(buf)
    }

    /// Flushes the wrapped stream.
    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        self.inner.get_mut().flush()
    }
}

// Give callers direct access to the underlying `BufReader`.
impl std::ops::Deref for BufReaderWrite {
    type Target = BufReader;
    fn deref(&self) -> &BufReader {
        &self.inner
    }
}

impl std::ops::DerefMut for BufReaderWrite {
    fn deref_mut(&mut self) -> &mut BufReader {
        &mut self.inner
    }
}

#[test]
fn test_read_line_lf() {
    let mut reader = BufReader::new(&b"hello\nworld\n"[..]);
    let mut line = Vec::new();

    assert_eq!(read_line(&mut reader, &mut line, u64::MAX).ok(), Some(6));
    assert_eq!(line, b"hello");

    assert_eq!(read_line(&mut reader, &mut line, u64::MAX).ok(), Some(6));
    assert_eq!(line, b"world");
}

#[test]
fn test_read_line_crlf() {
    let mut reader = BufReader::new(&b"hello\r\nworld\r\n"[..]);
    let mut line = Vec::new();

    assert_eq!(read_line(&mut reader, &mut line, u64::MAX).ok(), Some(7));
    assert_eq!(line, b"hello");

    assert_eq!(read_line(&mut reader, &mut line, u64::MAX).ok(), Some(7));
    assert_eq!(line, b"world");
}

#[test]
fn test_read_line_empty_crlf() {
    let mut reader = BufReader::new(&b"\r\n"[..]);
    let mut line = Vec::new();

    assert_eq!(read_line(&mut reader, &mut line, u64::MAX).ok(), Some(2));
    assert_eq!(line, b"");
}

#[test]
fn test_read_line_empty_lf() {
    let mut reader = BufReader::new(&b"\n"[..]);
    let mut line = Vec::new();

    assert_eq!(read_line(&mut reader, &mut line, u64::MAX).ok(), Some(1));
    assert_eq!(line, b"");
}

#[test]
fn test_read_line_beyond_limit() {
    let mut reader = BufReader::new(&b"1234567890\n"[..]);
    let mut line = Vec::new();

    assert_eq!(
        read_line(&mut reader, &mut line, 5).unwrap_err().kind(),
        io::ErrorKind::UnexpectedEof
    );
    assert_eq!(line, b"12345");
}

#[test]
fn test_read_line_strict() {
    let mut reader = BufReader::new(&b"foo\r\nbar\r\n"[..]);
    let mut line = Vec::new();

    assert_eq!(read_line_strict(&mut reader, &mut line, u64::MAX).ok(), Some(3 + 2));
    assert_eq!(line, b"foo");
}

#[test]
fn test_read_line_strict_empty_crlf() {
    let mut reader = BufReader::new(&b"\r\n"[..]);
    let mut line = Vec::new();

    assert_eq!(read_line_strict(&mut reader, &mut line, u64::MAX).ok(), Some(2));
    assert_eq!(line, b"");
}

#[test]
fn test_read_line_strict_missing_crlf() {
    let mut reader = BufReader::new(&b"foo\n"[..]);
    let mut line = Vec::new();

    assert_eq!(
        read_line_strict(&mut reader, &mut line, u64::MAX).unwrap_err().kind(),
        io::ErrorKind::UnexpectedEof
    );
    assert_eq!(line, b"foo\n");
}
206 | 207 | #[test] 208 | fn test_read_line_strict_inner_lf() { 209 | let mut reader = BufReader::new(&b"123\n456\n789\n0\r\nABC"[..]); 210 | let mut line = Vec::new(); 211 | 212 | assert_eq!( 213 | read_line_strict(&mut reader, &mut line, u64::MAX).ok(), 214 | Some(10 + 3 + 2) 215 | ); 216 | assert_eq!(line, b"123\n456\n789\n0"); 217 | } 218 | 219 | #[test] 220 | fn test_read_line_strict_inner_cr() { 221 | let mut reader = BufReader::new(&b"123\r456\r789\r0\r\nXYZ"[..]); 222 | let mut line = Vec::new(); 223 | 224 | assert_eq!( 225 | read_line_strict(&mut reader, &mut line, u64::MAX).ok(), 226 | Some(10 + 3 + 2) 227 | ); 228 | assert_eq!(line, b"123\r456\r789\r0"); 229 | } 230 | 231 | #[test] 232 | fn test_trim_byte() { 233 | assert_eq!(trim_byte(b' ', b" hello "), b"hello"); 234 | assert_eq!(trim_byte(b' ', b"hello"), b"hello"); 235 | assert_eq!(trim_byte(b' ', b""), b""); 236 | } 237 | 238 | #[test] 239 | fn test_trim_byte_left() { 240 | assert_eq!(trim_byte_left(b' ', b" hello"), b"hello"); 241 | assert_eq!(trim_byte_left(b' ', b"hello"), b"hello"); 242 | assert_eq!(trim_byte_left(b' ', b""), b""); 243 | } 244 | 245 | #[test] 246 | fn test_trim_byte_right() { 247 | assert_eq!(trim_byte_right(b' ', b"hello "), b"hello"); 248 | assert_eq!(trim_byte_right(b' ', b"hello"), b"hello"); 249 | assert_eq!(trim_byte_right(b' ', b""), b""); 250 | } 251 | -------------------------------------------------------------------------------- /src/parsing/chunked_reader.rs: -------------------------------------------------------------------------------- 1 | use std::cmp; 2 | use std::io::{self, BufRead, BufReader, Read}; 3 | use std::str; 4 | 5 | use crate::error::InvalidResponseKind; 6 | use crate::parsing::buffers; 7 | 8 | fn parse_chunk_size(line: &[u8]) -> io::Result { 9 | line.iter() 10 | .position(|&b| b == b';') 11 | .map_or_else(|| str::from_utf8(line), |idx| str::from_utf8(&line[..idx])) 12 | .map_err(|_| InvalidResponseKind::ChunkSize) 13 | .and_then(|line| 
usize::from_str_radix(line.trim(), 16).map_err(|_| InvalidResponseKind::ChunkSize)) 14 | .map_err(|e| e.into()) 15 | } 16 | 17 | #[derive(Debug)] 18 | pub struct ChunkedReader 19 | where 20 | R: Read, 21 | { 22 | inner: BufReader, 23 | buffer: Vec, 24 | consumed: usize, // bytes consumed from `buffer` 25 | remaining: usize, // bytes remaining until next chunk 26 | reached_eof: bool, 27 | } 28 | 29 | impl ChunkedReader 30 | where 31 | R: Read, 32 | { 33 | pub fn new(reader: BufReader) -> ChunkedReader { 34 | ChunkedReader { 35 | inner: reader, 36 | buffer: Vec::new(), 37 | consumed: 0, 38 | remaining: 0, 39 | reached_eof: false, 40 | } 41 | } 42 | 43 | fn read_chunk_size(&mut self) -> io::Result { 44 | buffers::read_line(&mut self.inner, &mut self.buffer, 128)?; 45 | if self.buffer.is_empty() { 46 | return Err(io::ErrorKind::UnexpectedEof.into()); 47 | } 48 | parse_chunk_size(&self.buffer) 49 | } 50 | } 51 | 52 | impl BufRead for ChunkedReader 53 | where 54 | R: Read, 55 | { 56 | fn fill_buf(&mut self) -> io::Result<&[u8]> { 57 | const MAX_BUFFER_LEN: usize = 64 * 1024; 58 | 59 | if self.buffer.len() == self.consumed && !(self.remaining == 0 && self.reached_eof) { 60 | if self.remaining == 0 { 61 | self.remaining = self.read_chunk_size()?; 62 | if self.remaining == 0 { 63 | self.reached_eof = true; 64 | } 65 | } 66 | 67 | self.buffer.resize(cmp::min(self.remaining, MAX_BUFFER_LEN), 0); 68 | self.inner.read_exact(&mut self.buffer)?; 69 | self.consumed = 0; 70 | self.remaining -= self.buffer.len(); 71 | 72 | if self.remaining == 0 && !buffers::read_line_ending(&mut self.inner)? 
{ 73 | self.buffer.clear(); 74 | self.reached_eof = true; 75 | 76 | return Err(InvalidResponseKind::Chunk.into()); 77 | } 78 | } 79 | 80 | Ok(&self.buffer[self.consumed..]) 81 | } 82 | 83 | fn consume(&mut self, amt: usize) { 84 | self.consumed = cmp::min(self.consumed + amt, self.buffer.len()); 85 | } 86 | } 87 | 88 | impl Read for ChunkedReader 89 | where 90 | R: Read, 91 | { 92 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 93 | let n = self.fill_buf()?.read(buf)?; 94 | self.consume(n); 95 | Ok(n) 96 | } 97 | } 98 | 99 | #[test] 100 | fn test_read_works() { 101 | let msg = b"4\r\nwiki\r\n5\r\npedia\r\nE\r\n in\r\n\r\nchunks.\r\n0\r\n\r\n"; 102 | let mut reader = ChunkedReader::new(BufReader::new(&msg[..])); 103 | let mut s = String::new(); 104 | reader.read_to_string(&mut s).unwrap(); 105 | assert_eq!(s, "wikipedia in\r\n\r\nchunks."); 106 | } 107 | 108 | #[test] 109 | fn test_read_empty() { 110 | let msg = b"0\r\n\r\n"; 111 | let mut reader = ChunkedReader::new(BufReader::new(&msg[..])); 112 | let mut s = String::new(); 113 | reader.read_to_string(&mut s).unwrap(); 114 | assert_eq!(s, ""); 115 | } 116 | 117 | #[test] 118 | fn test_read_invalid_empty() { 119 | let msg = b""; 120 | let mut reader = ChunkedReader::new(BufReader::new(&msg[..])); 121 | let mut s = String::new(); 122 | assert!(reader.read_to_string(&mut s).is_err()); 123 | } 124 | 125 | #[test] 126 | fn test_read_invalid_chunk() { 127 | let msg = b"4\r\nwik"; 128 | let mut reader = ChunkedReader::new(BufReader::new(&msg[..])); 129 | let mut s = String::new(); 130 | assert_eq!( 131 | reader.read_to_string(&mut s).err().unwrap().kind(), 132 | io::ErrorKind::UnexpectedEof 133 | ); 134 | } 135 | 136 | #[test] 137 | fn test_read_invalid_no_terminating_chunk() { 138 | let msg = b"4\r\nwiki"; 139 | let mut reader = ChunkedReader::new(BufReader::new(&msg[..])); 140 | let mut s = String::new(); 141 | assert_eq!( 142 | reader.read_to_string(&mut s).err().unwrap().kind(), 143 | 
io::ErrorKind::UnexpectedEof 144 | ); 145 | } 146 | 147 | #[test] 148 | fn test_read_invalid_bad_terminating_chunk() { 149 | let msg = b"4\r\nwiki\r\n0\r\n"; 150 | let mut reader = ChunkedReader::new(BufReader::new(&msg[..])); 151 | let mut s = String::new(); 152 | assert_eq!( 153 | reader.read_to_string(&mut s).err().unwrap().kind(), 154 | io::ErrorKind::UnexpectedEof 155 | ); 156 | } 157 | -------------------------------------------------------------------------------- /src/parsing/compressed_reader.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Read}; 2 | 3 | #[cfg(feature = "flate2")] 4 | use flate2::bufread::{DeflateDecoder, GzDecoder}; 5 | use http::header::HeaderMap; 6 | #[cfg(feature = "flate2")] 7 | use http::header::{CONTENT_ENCODING, TRANSFER_ENCODING}; 8 | #[cfg(feature = "flate2")] 9 | use http::Method; 10 | 11 | use crate::error::Result; 12 | use crate::parsing::body_reader::BodyReader; 13 | use crate::request::PreparedRequest; 14 | 15 | #[allow(clippy::large_enum_variant)] 16 | #[derive(Debug)] 17 | pub enum CompressedReader { 18 | Plain(BodyReader), 19 | #[cfg(feature = "flate2")] 20 | Deflate(DeflateDecoder), 21 | #[cfg(feature = "flate2")] 22 | Gzip(GzDecoder), 23 | } 24 | 25 | #[cfg(feature = "flate2")] 26 | fn have_encoding_item(value: &str, enc: &str) -> bool { 27 | value.split(',').map(|s| s.trim()).any(|s| s.eq_ignore_ascii_case(enc)) 28 | } 29 | 30 | #[cfg(feature = "flate2")] 31 | fn have_encoding_content_encoding(headers: &HeaderMap, enc: &str) -> bool { 32 | headers 33 | .get_all(CONTENT_ENCODING) 34 | .into_iter() 35 | .filter_map(|val| val.to_str().ok()) 36 | .any(|val| have_encoding_item(val, enc)) 37 | } 38 | 39 | #[cfg(feature = "flate2")] 40 | fn have_encoding_transfer_encoding(headers: &HeaderMap, enc: &str) -> bool { 41 | headers 42 | .get_all(TRANSFER_ENCODING) 43 | .into_iter() 44 | .filter_map(|val| val.to_str().ok()) 45 | .any(|val| have_encoding_item(val, enc)) 46 
| } 47 | 48 | #[cfg(feature = "flate2")] 49 | fn have_encoding(headers: &HeaderMap, enc: &str) -> bool { 50 | have_encoding_content_encoding(headers, enc) || have_encoding_transfer_encoding(headers, enc) 51 | } 52 | 53 | impl CompressedReader { 54 | #[cfg(feature = "flate2")] 55 | pub fn new(headers: &HeaderMap, request: &PreparedRequest, reader: BodyReader) -> Result { 56 | if request.method() != Method::HEAD { 57 | if have_encoding(headers, "gzip") { 58 | debug!("creating gzip decoder"); 59 | return Ok(CompressedReader::Gzip(GzDecoder::new(reader))); 60 | } 61 | 62 | if have_encoding(headers, "deflate") { 63 | debug!("creating deflate decoder"); 64 | return Ok(CompressedReader::Deflate(DeflateDecoder::new(reader))); 65 | } 66 | } 67 | debug!("creating plain reader"); 68 | Ok(CompressedReader::Plain(reader)) 69 | } 70 | 71 | #[cfg(not(feature = "flate2"))] 72 | pub fn new(_: &HeaderMap, _: &PreparedRequest, reader: BodyReader) -> Result { 73 | Ok(CompressedReader::Plain(reader)) 74 | } 75 | } 76 | 77 | impl Read for CompressedReader { 78 | #[inline] 79 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 80 | // TODO: gzip does not read until EOF, leaving some data in the buffer. 
81 | match self { 82 | CompressedReader::Plain(s) => s.read(buf), 83 | #[cfg(feature = "flate2")] 84 | CompressedReader::Deflate(s) => s.read(buf), 85 | #[cfg(feature = "flate2")] 86 | CompressedReader::Gzip(s) => s.read(buf), 87 | } 88 | } 89 | } 90 | 91 | #[cfg(test)] 92 | mod tests { 93 | use std::io::prelude::*; 94 | 95 | #[cfg(feature = "flate2")] 96 | use flate2::{ 97 | write::{DeflateEncoder, GzEncoder}, 98 | Compression, 99 | }; 100 | #[cfg(feature = "flate2")] 101 | use http::header::{HeaderMap, HeaderValue}; 102 | use http::Method; 103 | 104 | #[cfg(feature = "flate2")] 105 | use super::have_encoding; 106 | use crate::parsing::response::parse_response; 107 | use crate::streams::BaseStream; 108 | use crate::PreparedRequest; 109 | 110 | #[test] 111 | #[cfg(feature = "flate2")] 112 | fn test_have_encoding_none() { 113 | let mut headers = HeaderMap::new(); 114 | headers.insert("content-encoding", HeaderValue::from_static("gzip")); 115 | assert!(!have_encoding(&headers, "deflate")); 116 | } 117 | 118 | #[test] 119 | #[cfg(feature = "flate2")] 120 | fn test_have_encoding_content_encoding_simple() { 121 | let mut headers = HeaderMap::new(); 122 | headers.insert("content-encoding", HeaderValue::from_static("gzip")); 123 | assert!(have_encoding(&headers, "gzip")); 124 | } 125 | 126 | #[test] 127 | #[cfg(feature = "flate2")] 128 | fn test_have_encoding_content_encoding_multi() { 129 | let mut headers = HeaderMap::new(); 130 | headers.insert("content-encoding", HeaderValue::from_static("identity, deflate")); 131 | assert!(have_encoding(&headers, "deflate")); 132 | } 133 | 134 | #[test] 135 | #[cfg(feature = "flate2")] 136 | fn test_have_encoding_transfer_encoding_simple() { 137 | let mut headers = HeaderMap::new(); 138 | headers.insert("transfer-encoding", HeaderValue::from_static("deflate")); 139 | assert!(have_encoding(&headers, "deflate")); 140 | } 141 | 142 | #[test] 143 | #[cfg(feature = "flate2")] 144 | fn test_have_encoding_transfer_encoding_multi() { 145 | 
let mut headers = HeaderMap::new(); 146 | headers.insert("transfer-encoding", HeaderValue::from_static("gzip, chunked")); 147 | assert!(have_encoding(&headers, "gzip")); 148 | } 149 | 150 | #[test] 151 | fn test_stream_plain() { 152 | let payload = b"Hello world!!!!!!!!"; 153 | 154 | let mut buf: Vec = Vec::new(); 155 | let _ = write!(buf, "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", payload.len()); 156 | buf.extend(payload); 157 | 158 | let req = PreparedRequest::new(Method::GET, "http://google.ca"); 159 | 160 | let sock = BaseStream::mock(buf); 161 | let response = parse_response(sock, &req, req.url()).unwrap(); 162 | assert_eq!(response.text().unwrap(), "Hello world!!!!!!!!"); 163 | } 164 | 165 | #[test] 166 | #[cfg(feature = "flate2")] 167 | fn test_stream_deflate() { 168 | let mut payload = Vec::new(); 169 | let mut enc = DeflateEncoder::new(&mut payload, Compression::default()); 170 | enc.write_all(b"Hello world!!!!!!!!").unwrap(); 171 | enc.finish().unwrap(); 172 | 173 | let mut buf: Vec = Vec::new(); 174 | let _ = write!( 175 | buf, 176 | "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Encoding: deflate\r\n\r\n", 177 | payload.len() 178 | ); 179 | buf.extend(payload); 180 | 181 | let req = PreparedRequest::new(Method::GET, "http://google.ca"); 182 | 183 | let sock = BaseStream::mock(buf); 184 | let response = parse_response(sock, &req, req.url()).unwrap(); 185 | assert_eq!(response.text().unwrap(), "Hello world!!!!!!!!"); 186 | } 187 | 188 | #[test] 189 | #[cfg(feature = "flate2")] 190 | fn test_stream_gzip() { 191 | let mut payload = Vec::new(); 192 | let mut enc = GzEncoder::new(&mut payload, Compression::default()); 193 | enc.write_all(b"Hello world!!!!!!!!").unwrap(); 194 | enc.finish().unwrap(); 195 | 196 | let mut buf: Vec = Vec::new(); 197 | let _ = write!( 198 | buf, 199 | "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Encoding: gzip\r\n\r\n", 200 | payload.len() 201 | ); 202 | buf.extend(payload); 203 | 204 | let req = 
PreparedRequest::new(Method::GET, "http://google.ca"); 205 | 206 | let sock = BaseStream::mock(buf); 207 | let response = parse_response(sock, &req, req.url()).unwrap(); 208 | 209 | assert_eq!(response.text().unwrap(), "Hello world!!!!!!!!"); 210 | } 211 | 212 | #[test] 213 | #[cfg(feature = "flate2")] 214 | fn test_no_body_with_gzip() { 215 | let buf = b"HTTP/1.1 200 OK\r\ncontent-encoding: gzip\r\n\r\n"; 216 | 217 | let req = PreparedRequest::new(Method::GET, "http://google.ca"); 218 | let sock = BaseStream::mock(buf.to_vec()); 219 | // Fixed by the move from libflate to flate2 220 | assert!(parse_response(sock, &req, req.url()).is_ok()); 221 | } 222 | 223 | #[test] 224 | #[cfg(feature = "flate2")] 225 | fn test_no_body_with_gzip_head() { 226 | let buf = b"HTTP/1.1 200 OK\r\ncontent-encoding: gzip\r\n\r\n"; 227 | 228 | let req = PreparedRequest::new(Method::HEAD, "http://google.ca"); 229 | let sock = BaseStream::mock(buf.to_vec()); 230 | assert!(parse_response(sock, &req, req.url()).is_ok()); 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /src/parsing/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod body_reader; 2 | pub mod buffers; 3 | pub mod chunked_reader; 4 | pub mod compressed_reader; 5 | pub mod response; 6 | pub mod response_reader; 7 | #[cfg(feature = "charsets")] 8 | pub mod text_reader; 9 | 10 | pub use self::response::{parse_response, Response}; 11 | pub use self::response_reader::ResponseReader; 12 | #[cfg(feature = "charsets")] 13 | pub use self::text_reader::TextReader; 14 | -------------------------------------------------------------------------------- /src/parsing/response.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, BufReader, Read, Write}; 2 | use std::str; 3 | 4 | use http::{ 5 | header::{HeaderName, HeaderValue, TRANSFER_ENCODING}, 6 | HeaderMap, StatusCode, 7 | }; 8 | use 
url::Url; 9 | 10 | use crate::error::{ErrorKind, InvalidResponseKind, Result}; 11 | use crate::parsing::buffers::{self, trim_byte}; 12 | use crate::parsing::{body_reader::BodyReader, compressed_reader::CompressedReader, ResponseReader}; 13 | use crate::request::PreparedRequest; 14 | use crate::streams::BaseStream; 15 | 16 | #[cfg(feature = "charsets")] 17 | use crate::{charsets::Charset, parsing::TextReader}; 18 | 19 | #[cfg(feature = "json")] 20 | use serde::de::DeserializeOwned; 21 | 22 | pub fn parse_response_head(reader: &mut BufReader, max_headers: usize) -> Result<(StatusCode, HeaderMap)> 23 | where 24 | R: Read, 25 | { 26 | const MAX_LINE_LEN: u64 = 16 * 1024; 27 | 28 | let mut line = Vec::new(); 29 | let mut headers = HeaderMap::new(); 30 | 31 | // status line 32 | let status: StatusCode = { 33 | buffers::read_line(reader, &mut line, MAX_LINE_LEN)?; 34 | let mut parts = line.split(|&b| b == b' ').filter(|x| !x.is_empty()); 35 | 36 | let _ = parts.next().ok_or(InvalidResponseKind::StatusLine)?; 37 | let code = parts.next().ok_or(InvalidResponseKind::StatusLine)?; 38 | 39 | str::from_utf8(code) 40 | .map_err(|_| InvalidResponseKind::StatusCode)? 41 | .parse() 42 | .map_err(|_| InvalidResponseKind::StatusCode)? 
43 | }; 44 | 45 | // headers 46 | loop { 47 | buffers::read_line_strict(reader, &mut line, MAX_LINE_LEN)?; 48 | if line.is_empty() { 49 | break; 50 | } else if headers.len() == max_headers { 51 | return Err(InvalidResponseKind::Header.into()); 52 | } 53 | 54 | let col = line 55 | .iter() 56 | .position(|&c| c == b':') 57 | .ok_or(InvalidResponseKind::Header)?; 58 | 59 | buffers::replace_byte(b'\n', b' ', &mut line[col + 1..]); 60 | 61 | let header = trim_byte(b' ', &line[..col]); 62 | let value = trim_byte(b' ', &line[col + 1..]); 63 | 64 | let header = match HeaderName::from_bytes(header) { 65 | Ok(val) => val, 66 | Err(err) => { 67 | warn!("Dropped invalid response header: {}", err); 68 | continue; 69 | } 70 | }; 71 | 72 | headers.append(header, HeaderValue::from_bytes(value).map_err(http::Error::from)?); 73 | } 74 | 75 | Ok((status, headers)) 76 | } 77 | 78 | pub fn parse_response(reader: BaseStream, request: &PreparedRequest, url: &Url) -> Result { 79 | let mut reader = BufReader::new(reader); 80 | let (status, mut headers) = parse_response_head(&mut reader, request.base_settings.max_headers)?; 81 | let body_reader = BodyReader::new(&headers, reader)?; 82 | let compressed_reader = CompressedReader::new(&headers, request, body_reader)?; 83 | let response_reader = ResponseReader::new(&headers, request, compressed_reader); 84 | 85 | // Remove HOP-BY-HOP headers 86 | headers.remove(TRANSFER_ENCODING); 87 | 88 | Ok(Response { 89 | url: url.clone(), 90 | status, 91 | headers, 92 | reader: response_reader, 93 | }) 94 | } 95 | 96 | /// `Response` represents a response returned by a server. 97 | #[derive(Debug)] 98 | pub struct Response { 99 | url: Url, 100 | status: StatusCode, 101 | headers: HeaderMap, 102 | reader: ResponseReader, 103 | } 104 | 105 | impl Response { 106 | /// Get the final URL of this `Response`. 107 | #[inline] 108 | pub fn url(&self) -> &Url { 109 | &self.url 110 | } 111 | 112 | /// Get the status code of this `Response`. 
113 | #[inline] 114 | pub fn status(&self) -> StatusCode { 115 | self.status 116 | } 117 | 118 | /// Get the headers of this `Response`. 119 | #[inline] 120 | pub fn headers(&self) -> &HeaderMap { 121 | &self.headers 122 | } 123 | 124 | /// Checks if the status code of this `Response` was a success code. 125 | #[inline] 126 | pub fn is_success(&self) -> bool { 127 | self.status.is_success() 128 | } 129 | 130 | /// Returns error variant if the status code was not a success code. 131 | pub fn error_for_status(self) -> Result { 132 | if self.is_success() { 133 | Ok(self) 134 | } else { 135 | Err(ErrorKind::StatusCode(self.status).into()) 136 | } 137 | } 138 | 139 | /// Split this `Response` into a tuple of `StatusCode`, `HeaderMap`, `ResponseReader`. 140 | /// 141 | /// This method is useful to read the status code or headers after consuming the response. 142 | #[inline] 143 | pub fn split(self) -> (StatusCode, HeaderMap, ResponseReader) { 144 | (self.status, self.headers, self.reader) 145 | } 146 | 147 | /// Write the response to any object that implements `Write`. 148 | #[inline] 149 | pub fn write_to(self, writer: W) -> Result 150 | where 151 | W: Write, 152 | { 153 | self.reader.write_to(writer) 154 | } 155 | 156 | /// Read the response to a `Vec` of bytes. 157 | #[inline] 158 | pub fn bytes(self) -> Result> { 159 | self.reader.bytes() 160 | } 161 | 162 | /// Read the response to a `String`. 163 | /// 164 | /// If the `charsets` feature is enabled, it will try to decode the response using 165 | /// the encoding in the headers. If there's no encoding specified in the headers, 166 | /// it will fall back to the default encoding, and if that's also not specified, 167 | /// it will fall back to the default of ISO-8859-1. 168 | /// 169 | /// If the `charsets` feature is disabled, this method is the same as calling 170 | /// `text_utf8`. 171 | /// 172 | /// Note that both conversions are lossy, i.e. 
they will not raise errors when 173 | /// invalid data is encountered but output replacement characters instead. 174 | #[inline] 175 | pub fn text(self) -> Result { 176 | self.reader.text() 177 | } 178 | 179 | /// Read the response to a `String`, decoding with the given `Charset`. 180 | /// 181 | /// This will ignore the encoding from the response headers and the default encoding, if any. 182 | /// 183 | /// This method only exists when the `charsets` feature is enabled. 184 | #[cfg(feature = "charsets")] 185 | #[inline] 186 | pub fn text_with(self, charset: Charset) -> Result { 187 | self.reader.text_with(charset) 188 | } 189 | 190 | /// Create a `TextReader` from this `ResponseReader`. 191 | /// 192 | /// If the response headers contain charset information, that charset will be used to decode the body. 193 | /// Otherwise, if a default encoding is set it will be used. If there is no default encoding, ISO-8859-1 194 | /// will be used. 195 | /// 196 | /// This method only exists when the `charsets` feature is enabled. 197 | #[cfg(feature = "charsets")] 198 | pub fn text_reader(self) -> TextReader> { 199 | self.reader.text_reader() 200 | } 201 | 202 | /// Create a `TextReader` from this `ResponseReader`, decoding with the given `Charset`. 203 | /// 204 | /// This will ignore the encoding from the response headers and the default encoding, if any. 205 | /// 206 | /// This method only exists when the `charsets` feature is enabled. 207 | #[cfg(feature = "charsets")] 208 | #[inline] 209 | pub fn text_reader_with(self, charset: Charset) -> TextReader> { 210 | self.reader.text_reader_with(charset) 211 | } 212 | 213 | /// Read the response body to a String using the UTF-8 encoding. 214 | /// 215 | /// This method ignores headers and the default encoding. 216 | /// 217 | /// Note that is lossy, i.e. it will not raise errors when 218 | /// invalid data is encountered but output replacement characters instead. 
219 | #[inline] 220 | pub fn text_utf8(self) -> Result { 221 | self.reader.text_utf8() 222 | } 223 | 224 | /// Parse the response as a JSON object and return it. 225 | /// 226 | /// If the `charsets` feature is enabled, it will try to decode the response using 227 | /// the encoding in the headers. If there's no encoding specified in the headers, 228 | /// it will fall back to the default encoding, and if that's also not specified, 229 | /// it will fall back to the default of ISO-8859-1. 230 | /// 231 | /// If the `charsets` feature is disabled, this method is the same as calling 232 | /// `json_utf8`. 233 | #[cfg(feature = "json")] 234 | #[inline] 235 | pub fn json(self) -> Result 236 | where 237 | T: DeserializeOwned, 238 | { 239 | self.reader.json() 240 | } 241 | 242 | /// Parse the response as a JSON object encoded in UTF-8. 243 | /// 244 | /// This method ignores headers and the default encoding. 245 | /// 246 | /// This method only exists when the `json` feature is enabled. 247 | #[cfg(feature = "json")] 248 | #[inline] 249 | pub fn json_utf8(self) -> Result 250 | where 251 | T: DeserializeOwned, 252 | { 253 | self.reader.json_utf8() 254 | } 255 | } 256 | 257 | impl Read for Response { 258 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 259 | self.reader.read(buf) 260 | } 261 | } 262 | 263 | #[test] 264 | fn test_read_request_head() { 265 | let response = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\nContent-Type: text/plain\r\n\r\nhello"; 266 | let mut reader = BufReader::new(&response[..]); 267 | let (status, headers) = parse_response_head(&mut reader, 100).unwrap(); 268 | assert_eq!(status, StatusCode::OK); 269 | assert_eq!(headers.len(), 2); 270 | assert_eq!(headers[http::header::CONTENT_LENGTH], "5"); 271 | assert_eq!(headers[http::header::CONTENT_TYPE], "text/plain"); 272 | } 273 | 274 | #[test] 275 | fn test_line_folded_header() { 276 | let response = b"HTTP/1.1 200 OK\r\nheader-of-great-many-lines: foo\nbar\nbaz\nqux\r\nthe-other-kind-of-header: 
foobar\r\n\r\n"; 277 | let mut reader = BufReader::new(&response[..]); 278 | let (status, headers) = parse_response_head(&mut reader, 100).unwrap(); 279 | assert_eq!(status, StatusCode::OK); 280 | assert_eq!(headers.len(), 2); 281 | assert_eq!(headers["header-of-great-many-lines"], "foo bar baz qux"); 282 | assert_eq!(headers["the-other-kind-of-header"], "foobar"); 283 | } 284 | 285 | #[test] 286 | fn test_max_headers_limit() { 287 | let response = b"HTTP/1.1 200 OK\r\nfirst-header: foo\r\nsecond-header: bar\r\none-header-too-many: baz\r\n\r\n"; 288 | let mut reader = BufReader::new(&response[..]); 289 | let err = parse_response_head(&mut reader, 2).unwrap_err(); 290 | assert!(matches!( 291 | err.kind(), 292 | ErrorKind::InvalidResponse(InvalidResponseKind::Header) 293 | )); 294 | } 295 | -------------------------------------------------------------------------------- /src/parsing/response_reader.rs: -------------------------------------------------------------------------------- 1 | #[cfg(any(feature = "charsets", feature = "json"))] 2 | use std::io::BufReader; 3 | use std::io::{self, Read, Write}; 4 | 5 | use http::header::HeaderMap; 6 | #[cfg(feature = "json")] 7 | use serde::de::DeserializeOwned; 8 | 9 | use crate::error::Result; 10 | use crate::parsing::compressed_reader::CompressedReader; 11 | use crate::request::PreparedRequest; 12 | 13 | #[cfg(feature = "charsets")] 14 | use { 15 | crate::{ 16 | charsets::{self, Charset}, 17 | parsing::buffers::trim_byte, 18 | parsing::TextReader, 19 | }, 20 | encoding_rs::Encoding, 21 | http::header::CONTENT_TYPE, 22 | }; 23 | 24 | #[cfg(feature = "charsets")] 25 | fn get_charset(headers: &HeaderMap, default_charset: Option) -> Charset { 26 | if let Some(value) = headers.get(CONTENT_TYPE) { 27 | let bytes = value.as_bytes(); 28 | if let Some(scol) = bytes.iter().position(|&b| b == b';') { 29 | let rhs = trim_byte(b' ', &bytes[scol + 1..]); 30 | if rhs.starts_with(b"charset=") { 31 | if let Some(enc) = 
Encoding::for_label(&rhs[8..]) { 32 | return enc; 33 | } 34 | } 35 | } 36 | } 37 | default_charset.unwrap_or(charsets::WINDOWS_1252) 38 | } 39 | 40 | /// The `ResponseReader` is used to read the body of a response. 41 | /// 42 | /// The `ResponseReader` implements `Read` and can be used like any other stream, 43 | /// but the data returned by `Read` are untouched bytes from the socket. This means 44 | /// that if a string is expected back, it could be in a different encoding than the 45 | /// expected one. In order to properly read text, use the `charsets` feature and the 46 | /// `text` or `text_reader` methods. 47 | /// 48 | /// In general it's best to avoid `Read`ing directly from this object. Instead use the 49 | /// helper methods, they process the data stream properly. 50 | #[derive(Debug)] 51 | pub struct ResponseReader { 52 | inner: CompressedReader, 53 | #[cfg(feature = "charsets")] 54 | charset: Charset, 55 | } 56 | 57 | impl ResponseReader { 58 | #[cfg(feature = "charsets")] 59 | pub(crate) fn new( 60 | headers: &HeaderMap, 61 | request: &PreparedRequest, 62 | reader: CompressedReader, 63 | ) -> ResponseReader { 64 | ResponseReader { 65 | inner: reader, 66 | charset: get_charset(headers, request.base_settings.default_charset), 67 | } 68 | } 69 | 70 | #[cfg(not(feature = "charsets"))] 71 | pub(crate) fn new(_: &HeaderMap, _: &PreparedRequest, reader: CompressedReader) -> ResponseReader { 72 | ResponseReader { inner: reader } 73 | } 74 | 75 | /// Write the response to any object that implements `Write`. 76 | pub fn write_to(mut self, mut writer: W) -> Result 77 | where 78 | W: Write, 79 | { 80 | let n = io::copy(&mut self.inner, &mut writer)?; 81 | Ok(n) 82 | } 83 | 84 | /// Read the response to a `Vec` of bytes. 85 | pub fn bytes(self) -> Result> { 86 | let mut buf = Vec::new(); 87 | self.write_to(&mut buf)?; 88 | Ok(buf) 89 | } 90 | 91 | /// Read the response to a `String`. 
92 | /// 93 | /// If the `charsets` feature is enabled, it will try to decode the response using 94 | /// the encoding in the headers. If there's no encoding specified in the headers, 95 | /// it will fall back to the default encoding, and if that's also not specified, 96 | /// it will fall back to the default of ISO-8859-1. 97 | /// 98 | /// If the `charsets` feature is disabled, this method is the same as calling 99 | /// `text_utf8`. 100 | /// 101 | /// Note that both conversions are lossy, i.e. they will not raise errors when 102 | /// invalid data is encountered but output replacement characters instead. 103 | #[cfg(not(feature = "charsets"))] 104 | pub fn text(self) -> Result { 105 | self.text_utf8() 106 | } 107 | 108 | /// Read the response to a `String`. 109 | /// 110 | /// If the `charsets` feature is enabled, it will try to decode the response using 111 | /// the encoding in the headers. If there's no encoding specified in the headers, 112 | /// it will fall back to the default encoding, and if that's also not specified, 113 | /// it will fall back to the default of ISO-8859-1. 114 | /// 115 | /// If the `charsets` feature is disabled, this method is the same as calling 116 | /// `text_utf8`. 117 | /// 118 | /// Note that both conversions are lossy, i.e. they will not raise errors when 119 | /// invalid data is encountered but output replacement characters instead. 120 | #[cfg(feature = "charsets")] 121 | pub fn text(self) -> Result { 122 | let charset = self.charset; 123 | self.text_with(charset) 124 | } 125 | 126 | /// Read the response to a `String`, decoding with the given `Charset`. 127 | /// 128 | /// This will ignore the encoding from the response headers and the default encoding, if any. 129 | /// 130 | /// This method only exists when the `charsets` feature is enabled. 
131 | #[cfg(feature = "charsets")] 132 | pub fn text_with(self, charset: Charset) -> Result { 133 | let mut reader = self.text_reader_with(charset); 134 | let mut text = String::new(); 135 | reader.read_to_string(&mut text)?; 136 | Ok(text) 137 | } 138 | 139 | /// Create a `TextReader` from this `ResponseReader`. 140 | /// 141 | /// If the response headers contain charset information, that charset will be used to decode the body. 142 | /// Otherwise, if a default encoding is set it will be used. If there is no default encoding, ISO-8859-1 143 | /// will be used. 144 | /// 145 | /// This method only exists when the `charsets` feature is enabled. 146 | #[cfg(feature = "charsets")] 147 | pub fn text_reader(self) -> TextReader> { 148 | let charset = self.charset; 149 | self.text_reader_with(charset) 150 | } 151 | 152 | /// Create a `TextReader` from this `ResponseReader`, decoding with the given `Charset`. 153 | /// 154 | /// This will ignore the encoding from the response headers and the default encoding, if any. 155 | /// 156 | /// This method only exists when the `charsets` feature is enabled. 157 | #[cfg(feature = "charsets")] 158 | pub fn text_reader_with(self, charset: Charset) -> TextReader> { 159 | TextReader::new(BufReader::new(self), charset) 160 | } 161 | 162 | /// Read the response body to a String using the UTF-8 encoding. 163 | /// 164 | /// This method ignores headers and the default encoding. 165 | /// 166 | /// Note that this is lossy, i.e. it will not raise errors when 167 | /// invalid data is encountered but output replacement characters instead. 168 | pub fn text_utf8(mut self) -> Result { 169 | let mut buf = Vec::new(); 170 | self.inner.read_to_end(&mut buf)?; 171 | 172 | let text = String::from_utf8(buf).unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned()); 173 | 174 | Ok(text) 175 | } 176 | 177 | /// Parse the response as a JSON object and return it. 
178 | /// 179 | /// If the `charsets` feature is enabled, it will try to decode the response using 180 | /// the encoding in the headers. If there's no encoding specified in the headers, 181 | /// it will fall back to the default encoding, and if that's also not specified, 182 | /// it will fall back to the default of ISO-8859-1. 183 | /// 184 | /// If the `charsets` feature is disabled, this method is the same as calling 185 | /// `json_utf8`. 186 | #[cfg(feature = "json")] 187 | #[cfg(feature = "charsets")] 188 | pub fn json(self) -> Result 189 | where 190 | T: DeserializeOwned, 191 | { 192 | let reader = BufReader::new(self.text_reader()); 193 | let obj = serde_json::from_reader(reader)?; 194 | Ok(obj) 195 | } 196 | 197 | /// Parse the response as a JSON object and return it. 198 | /// 199 | /// If the `charsets` feature is enabled, it will try to decode the response using 200 | /// the encoding in the headers. If there's no encoding specified in the headers, 201 | /// it will fall back to the default encoding, and if that's also not specified, 202 | /// it will fall back to the default of ISO-8859-1. 203 | /// 204 | /// If the `charsets` feature is disabled, this method is the same as calling 205 | /// `json_utf8`. 206 | #[cfg(feature = "json")] 207 | #[cfg(not(feature = "charsets"))] 208 | pub fn json(self) -> Result 209 | where 210 | T: DeserializeOwned, 211 | { 212 | self.json_utf8() 213 | } 214 | 215 | /// Parse the response as a JSON object encoded in UTF-8. 216 | /// 217 | /// This method ignores headers and the default encoding. 218 | /// 219 | /// This method only exists when the `json` feature is enabled. 
220 | #[cfg(feature = "json")] 221 | pub fn json_utf8(self) -> Result 222 | where 223 | T: DeserializeOwned, 224 | { 225 | let reader = BufReader::new(self); 226 | let obj = serde_json::from_reader(reader)?; 227 | Ok(obj) 228 | } 229 | } 230 | 231 | impl Read for ResponseReader { 232 | #[inline] 233 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 234 | self.inner.read(buf) 235 | } 236 | } 237 | 238 | #[cfg(test)] 239 | #[cfg(feature = "charsets")] 240 | mod tests { 241 | use http::header::{HeaderMap, HeaderValue, CONTENT_TYPE}; 242 | 243 | use super::get_charset; 244 | use crate::charsets; 245 | 246 | #[test] 247 | fn test_get_charset_from_header() { 248 | let mut headers = HeaderMap::new(); 249 | headers.insert( 250 | CONTENT_TYPE, 251 | HeaderValue::from_bytes(&b"text/html; charset=UTF-8"[..]).unwrap(), 252 | ); 253 | assert_eq!(get_charset(&headers, None), charsets::UTF_8); 254 | } 255 | 256 | #[test] 257 | fn test_get_charset_from_header_lowercase() { 258 | let mut headers = HeaderMap::new(); 259 | headers.insert( 260 | CONTENT_TYPE, 261 | HeaderValue::from_bytes(&b"text/html; charset=utf8"[..]).unwrap(), 262 | ); 263 | assert_eq!(get_charset(&headers, None), charsets::UTF_8); 264 | } 265 | 266 | #[test] 267 | fn test_get_charset_from_default() { 268 | let headers = HeaderMap::new(); 269 | assert_eq!(get_charset(&headers, Some(charsets::UTF_8)), charsets::UTF_8); 270 | } 271 | 272 | #[test] 273 | fn test_get_charset_standard() { 274 | let headers = HeaderMap::new(); 275 | assert_eq!(get_charset(&headers, None), charsets::WINDOWS_1252); 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /src/parsing/text_reader.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Read}; 2 | 3 | use encoding_rs_io::{DecodeReaderBytes, DecodeReaderBytesBuilder}; 4 | 5 | use crate::charsets::Charset; 6 | 7 | /// `TextReader` converts bytes in a specific charset to bytes 
in UTF-8. 8 | /// 9 | /// It can be used to convert a stream of text in a specific charset into a stream 10 | /// of UTF-8 encoded bytes. The `Read::read_to_string` method can be used to convert 11 | /// the stream of UTF-8 bytes into a `String`. 12 | #[derive(Debug)] 13 | pub struct TextReader(DecodeReaderBytes>); 14 | 15 | impl TextReader 16 | where 17 | R: Read, 18 | { 19 | /// Create a new `TextReader` with the given charset. 20 | pub fn new(inner: R, charset: Charset) -> Self { 21 | Self(DecodeReaderBytesBuilder::new().encoding(Some(charset)).build(inner)) 22 | } 23 | } 24 | 25 | impl Read for TextReader 26 | where 27 | R: Read, 28 | { 29 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 30 | self.0.read(buf) 31 | } 32 | } 33 | 34 | #[test] 35 | fn test_stream_decoder_utf8() { 36 | let mut reader = TextReader::new("québec".as_bytes(), crate::charsets::UTF_8); 37 | 38 | let mut text = String::new(); 39 | assert_eq!(reader.read_to_string(&mut text).ok(), Some(7)); 40 | 41 | assert_eq!(text, "québec"); 42 | } 43 | 44 | #[test] 45 | fn test_stream_decoder_latin1() { 46 | let mut reader = TextReader::new(&b"qu\xC9bec"[..], crate::charsets::WINDOWS_1252); 47 | 48 | let mut text = String::new(); 49 | assert_eq!(reader.read_to_string(&mut text).ok(), Some(7)); 50 | 51 | assert_eq!(text, "quÉbec"); 52 | } 53 | 54 | #[test] 55 | fn test_string_reader_large_buffer_latin1() { 56 | let buf = vec![201; 10_000]; 57 | let mut reader = TextReader::new(&buf[..], crate::charsets::WINDOWS_1252); 58 | 59 | let mut text = String::new(); 60 | assert_eq!(20_000, reader.read_to_string(&mut text).unwrap()); 61 | assert_eq!(text.len(), 20_000); 62 | 63 | for c in text.chars() { 64 | assert_eq!(c, 'É'); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/request/body.rs: -------------------------------------------------------------------------------- 1 | use std::convert::TryInto; 2 | use std::fs; 3 | use std::io::{copy, Result as 
IoResult, Seek, SeekFrom, Write}; 4 | 5 | /// The kinds of request bodies currently supported by this crate. 6 | #[derive(Debug, Clone, Copy)] 7 | pub enum BodyKind { 8 | /// An empty request body 9 | Empty, 10 | /// A request body with a known length 11 | KnownLength(u64), 12 | /// A request body that is transferred using chunked encoding 13 | Chunked, 14 | } 15 | 16 | /// A generic rewindable request body 17 | pub trait Body { 18 | /// Determine the kind of the request body 19 | fn kind(&mut self) -> IoResult; 20 | 21 | /// Write out the request body into the given writer 22 | /// 23 | /// This method can be called multiple times if a request is redirected. 24 | fn write(&mut self, writer: W) -> IoResult<()>; 25 | 26 | /// Gets the content type this body is tied to if it has one. 27 | fn content_type(&mut self) -> IoResult> { 28 | Ok(None) 29 | } 30 | } 31 | 32 | /// An empty request body 33 | #[derive(Debug, Clone, Copy)] 34 | pub struct Empty; 35 | 36 | impl Body for Empty { 37 | fn kind(&mut self) -> IoResult { 38 | Ok(BodyKind::Empty) 39 | } 40 | 41 | fn write(&mut self, _writer: W) -> IoResult<()> { 42 | Ok(()) 43 | } 44 | } 45 | 46 | /// A request body containing UTF-8-encoded text 47 | #[derive(Debug, Clone)] 48 | pub struct Text(pub B); 49 | 50 | impl> Body for Text { 51 | fn kind(&mut self) -> IoResult { 52 | let len = self.0.as_ref().len().try_into().unwrap(); 53 | Ok(BodyKind::KnownLength(len)) 54 | } 55 | 56 | fn write(&mut self, mut writer: W) -> IoResult<()> { 57 | writer.write_all(self.0.as_ref().as_bytes()) 58 | } 59 | } 60 | 61 | /// A request body containing binary data 62 | #[derive(Debug, Clone)] 63 | pub struct Bytes(pub B); 64 | 65 | impl> Body for Bytes { 66 | fn kind(&mut self) -> IoResult { 67 | let len = self.0.as_ref().len().try_into().unwrap(); 68 | Ok(BodyKind::KnownLength(len)) 69 | } 70 | 71 | fn write(&mut self, mut writer: W) -> IoResult<()> { 72 | writer.write_all(self.0.as_ref()) 73 | } 74 | } 75 | 76 | /// A request body backed 
by a local file 77 | #[derive(Debug)] 78 | pub struct File(pub fs::File); 79 | 80 | impl Body for File { 81 | fn kind(&mut self) -> IoResult { 82 | let len = self.0.seek(SeekFrom::End(0))?; 83 | Ok(BodyKind::KnownLength(len)) 84 | } 85 | 86 | fn write(&mut self, mut writer: W) -> IoResult<()> { 87 | self.0.rewind()?; 88 | copy(&mut self.0, &mut writer)?; 89 | Ok(()) 90 | } 91 | } 92 | 93 | pub(crate) struct ChunkedWriter(pub W); 94 | 95 | impl ChunkedWriter { 96 | pub fn close(mut self) -> IoResult<()> { 97 | self.0.write_all(b"0\r\n\r\n") 98 | } 99 | } 100 | 101 | impl Write for ChunkedWriter { 102 | fn write(&mut self, buf: &[u8]) -> IoResult { 103 | write!(self.0, "{:x}\r\n", buf.len())?; 104 | self.0.write_all(buf)?; 105 | write!(self.0, "\r\n")?; 106 | Ok(buf.len()) 107 | } 108 | 109 | fn flush(&mut self) -> IoResult<()> { 110 | self.0.flush() 111 | } 112 | } 113 | 114 | #[cfg(feature = "json")] 115 | mod json { 116 | use super::*; 117 | 118 | use std::io::BufWriter; 119 | 120 | use serde::ser::Serialize; 121 | use serde_json::ser::to_writer; 122 | 123 | /// A request body for streaming out JSON 124 | #[derive(Debug, Clone)] 125 | pub struct Json(pub B); 126 | 127 | impl Body for Json { 128 | fn kind(&mut self) -> IoResult { 129 | Ok(BodyKind::Chunked) 130 | } 131 | 132 | fn write(&mut self, writer: W) -> IoResult<()> { 133 | let mut writer = BufWriter::new(writer); 134 | to_writer(&mut writer, &self.0)?; 135 | writer.flush()?; 136 | Ok(()) 137 | } 138 | } 139 | } 140 | 141 | #[cfg(feature = "json")] 142 | pub use json::Json; 143 | -------------------------------------------------------------------------------- /src/request/mod.rs: -------------------------------------------------------------------------------- 1 | use std::convert::{From, TryInto}; 2 | use std::io::{prelude::*, BufWriter}; 3 | use std::str; 4 | use std::sync::Arc; 5 | use std::time::Instant; 6 | 7 | #[cfg(feature = "flate2")] 8 | use http::header::ACCEPT_ENCODING; 9 | use http::{ 10 | 
header::{HeaderValue, IntoHeaderName, HOST}, 11 | HeaderMap, Method, StatusCode, Version, 12 | }; 13 | use url::Url; 14 | 15 | use crate::error::{Error, ErrorKind, InvalidResponseKind, Result}; 16 | use crate::parsing::{parse_response, Response}; 17 | use crate::streams::{BaseStream, ConnectInfo}; 18 | 19 | /// Contains types to describe request bodies 20 | pub mod body; 21 | mod builder; 22 | pub mod proxy; 23 | mod session; 24 | mod settings; 25 | 26 | use body::{Body, BodyKind}; 27 | pub use builder::{RequestBuilder, RequestInspector}; 28 | pub use session::Session; 29 | pub(crate) use settings::BaseSettings; 30 | 31 | fn header_insert(headers: &mut HeaderMap, header: H, value: V) -> Result 32 | where 33 | H: IntoHeaderName, 34 | V: TryInto, 35 | Error: From, 36 | { 37 | let value = value.try_into()?; 38 | headers.insert(header, value); 39 | Ok(()) 40 | } 41 | 42 | fn header_insert_if_missing(headers: &mut HeaderMap, header: H, value: V) -> Result 43 | where 44 | H: IntoHeaderName, 45 | V: TryInto, 46 | Error: From, 47 | { 48 | let value = value.try_into()?; 49 | headers.entry(header).or_insert(value); 50 | Ok(()) 51 | } 52 | 53 | fn header_append(headers: &mut HeaderMap, header: H, value: V) -> Result 54 | where 55 | H: IntoHeaderName, 56 | V: TryInto, 57 | Error: From, 58 | { 59 | let value = value.try_into()?; 60 | headers.append(header, value); 61 | Ok(()) 62 | } 63 | 64 | /// Represents a request that's ready to be sent. You can inspect this object for information about the request. 
#[derive(Debug)]
pub struct PreparedRequest<B> {
    url: Url,
    method: Method,
    body: B,
    headers: HeaderMap,
    pub(crate) base_settings: Arc<BaseSettings>,
}

#[cfg(test)]
impl PreparedRequest<body::Empty> {
    /// Build a bare request with an empty body, no headers and default settings (tests only).
    ///
    /// Panics if `base_url` cannot be parsed as an absolute URL.
    pub(crate) fn new<U>(method: Method, base_url: U) -> Self
    where
        U: AsRef<str>,
    {
        let url = Url::parse(base_url.as_ref()).unwrap();
        let base_settings = Arc::new(BaseSettings::default());
        PreparedRequest {
            url,
            method,
            body: body::Empty,
            headers: HeaderMap::new(),
            base_settings,
        }
    }
}

impl<B> PreparedRequest<B> {
    /// Without the `flate2` feature no `Accept-Encoding` header is ever added.
    #[cfg(not(feature = "flate2"))]
    fn set_compression(&mut self) -> Result {
        Ok(())
    }

    /// When the settings allow compression, advertise `gzip, deflate` through
    /// `Accept-Encoding`, replacing any value already present for that header.
    #[cfg(feature = "flate2")]
    fn set_compression(&mut self) -> Result {
        if !self.base_settings.allow_compression {
            return Ok(());
        }
        header_insert(&mut self.headers, ACCEPT_ENCODING, "gzip, deflate")
    }

    /// Turn the value of a `Location` header into the URL to request next.
    ///
    /// Absolute locations are used as-is; relative ones are resolved against `previous_url`.
    /// Any parse failure is reported as `InvalidResponseKind::RedirectionUrl`.
    fn base_redirect_url(&self, location: &str, previous_url: &Url) -> Result<Url> {
        let resolved = match Url::parse(location) {
            Err(url::ParseError::RelativeUrlWithoutBase) => previous_url.join(location),
            absolute => absolute,
        };
        resolved.map_err(|_| Error::from(InvalidResponseKind::RedirectionUrl))
    }

    /// Write every header as `name: value\r\n`, then the empty line that ends the header section.
    fn write_headers<W>(&self, writer: &mut W) -> Result
    where
        W: Write,
    {
        for (name, value) in &self.headers {
            writer.write_all(name.as_str().as_bytes())?;
            writer.write_all(b": ")?;
            // Values are written as raw bytes; they are not required to be valid UTF-8.
            writer.write_all(value.as_bytes())?;
            writer.write_all(b"\r\n")?;
        }
        writer.write_all(b"\r\n")?;
        Ok(())
    }

    /// Get the URL of this request.
    pub fn url(&self) -> &Url {
        &self.url
    }

    /// Get the method of this request.
    pub fn method(&self) -> &Method {
        &self.method
    }

    /// Get the body of the request.
142 | pub fn body(&self) -> &B { 143 | &self.body 144 | } 145 | 146 | /// Get the headers of this request. 147 | pub fn headers(&self) -> &HeaderMap { 148 | &self.headers 149 | } 150 | } 151 | 152 | impl PreparedRequest { 153 | fn write_request(&mut self, writer: W, url: &Url, proxy: Option<&Url>) -> Result 154 | where 155 | W: Write, 156 | { 157 | let mut writer = BufWriter::new(writer); 158 | let version = Version::HTTP_11; 159 | 160 | if proxy.is_some() && url.scheme() == "http" { 161 | debug!("{} {} {:?}", self.method.as_str(), url, version); 162 | 163 | write!(writer, "{} {} {:?}\r\n", self.method.as_str(), url, version)?; 164 | } else if let Some(query) = url.query() { 165 | debug!("{} {}?{} {:?}", self.method.as_str(), url.path(), query, version); 166 | 167 | write!( 168 | writer, 169 | "{} {}?{} {:?}\r\n", 170 | self.method.as_str(), 171 | url.path(), 172 | query, 173 | version, 174 | )?; 175 | } else { 176 | debug!("{} {} {:?}", self.method.as_str(), url.path(), version); 177 | 178 | write!(writer, "{} {} {:?}\r\n", self.method.as_str(), url.path(), version)?; 179 | } 180 | 181 | self.write_headers(&mut writer)?; 182 | 183 | match self.body.kind()? { 184 | BodyKind::Empty => (), 185 | BodyKind::KnownLength(len) => { 186 | debug!("writing out body of length {}", len); 187 | self.body.write(&mut writer)?; 188 | } 189 | BodyKind::Chunked => { 190 | debug!("writing out chunked body"); 191 | let mut writer = body::ChunkedWriter(&mut writer); 192 | self.body.write(&mut writer)?; 193 | writer.close()?; 194 | } 195 | } 196 | 197 | writer.flush()?; 198 | 199 | Ok(()) 200 | } 201 | 202 | /// Send this request and wait for the result. 
pub fn send(&mut self) -> Result<Response> {
        let mut url = self.url.clone();

        // One deadline covers the whole exchange, redirects included.
        let deadline = self.base_settings.timeout.map(|timeout| Instant::now() + timeout);
        let mut redirections = 0;

        loop {
            // Proxy handling:
            // - `http` URL with a proxy: connect to the proxy and send a request whose target
            //   is an absolute URL (with authority) instead of a bare path;
            // - `https` URL with a proxy: connect to the proxy using the CONNECT method, then
            //   send the https traffic on that socket once the CONNECT handshake is done.
            let proxy = self.base_settings.proxy_settings.for_url(&url).cloned();

            // For plain HTTP through a proxy, the Host header is taken from the proxy URL;
            // in every other case it comes from the target URL.
            let host_url = match (url.scheme(), proxy.as_ref()) {
                ("http", Some(proxy_url)) => proxy_url,
                _ => &url,
            };
            set_host(&mut self.headers, host_url)?;

            let mut stream = BaseStream::connect(&ConnectInfo {
                url: &url,
                proxy: proxy.as_ref(),
                base_settings: &self.base_settings,
                deadline,
            })?;

            self.write_request(&mut stream, &url, proxy.as_ref())?;
            let resp = parse_response(stream, self, &url)?;

            debug!("status code {}", resp.status().as_u16());

            let should_follow = self.base_settings.follow_redirects
                && matches!(
                    resp.status(),
                    StatusCode::MOVED_PERMANENTLY
                        | StatusCode::FOUND
                        | StatusCode::SEE_OTHER
                        | StatusCode::TEMPORARY_REDIRECT
                        | StatusCode::PERMANENT_REDIRECT
                );
            if !should_follow {
                return Ok(resp);
            }

            redirections += 1;
            if redirections > self.base_settings.max_redirections {
                return Err(ErrorKind::TooManyRedirections.into());
            }

            // Follow the redirect: the next URL comes from the `Location` header.
            let location = match resp.headers().get(http::header::LOCATION) {
                Some(value) => value,
                None => return Err(InvalidResponseKind::LocationHeader.into()),
            };

            let location =
String::from_utf8_lossy(location.as_bytes()); 262 | 263 | url = self.base_redirect_url(&location, &url)?; 264 | 265 | debug!("redirected to {} giving url {}", location, url); 266 | } 267 | } 268 | } 269 | 270 | fn set_host(headers: &mut HeaderMap, url: &Url) -> Result { 271 | let host = url.host_str().ok_or(ErrorKind::InvalidUrlHost)?; 272 | if let Some(port) = url.port() { 273 | header_insert(headers, HOST, format!("{host}:{port}"))?; 274 | } else { 275 | header_insert(headers, HOST, host)?; 276 | } 277 | Ok(()) 278 | } 279 | 280 | #[cfg(test)] 281 | mod test { 282 | use std::sync::Arc; 283 | 284 | use http::header::{HeaderMap, HeaderValue, USER_AGENT}; 285 | use http::Method; 286 | use url::Url; 287 | 288 | use super::BaseSettings; 289 | use super::{header_append, header_insert, header_insert_if_missing, PreparedRequest}; 290 | use crate::body::Empty; 291 | 292 | #[test] 293 | fn test_header_insert_exists() { 294 | let mut headers = HeaderMap::new(); 295 | headers.insert(USER_AGENT, HeaderValue::from_static("hello")); 296 | header_insert(&mut headers, USER_AGENT, "world").unwrap(); 297 | assert_eq!(headers[USER_AGENT], "world"); 298 | } 299 | 300 | #[test] 301 | fn test_header_insert_missing() { 302 | let mut headers = HeaderMap::new(); 303 | header_insert(&mut headers, USER_AGENT, "world").unwrap(); 304 | assert_eq!(headers[USER_AGENT], "world"); 305 | } 306 | 307 | #[test] 308 | fn test_header_insert_if_missing_exists() { 309 | let mut headers = HeaderMap::new(); 310 | headers.insert(USER_AGENT, HeaderValue::from_static("hello")); 311 | header_insert_if_missing(&mut headers, USER_AGENT, "world").unwrap(); 312 | assert_eq!(headers[USER_AGENT], "hello"); 313 | } 314 | 315 | #[test] 316 | fn test_header_insert_if_missing_missing() { 317 | let mut headers = HeaderMap::new(); 318 | header_insert_if_missing(&mut headers, USER_AGENT, "world").unwrap(); 319 | assert_eq!(headers[USER_AGENT], "world"); 320 | } 321 | 322 | #[test] 323 | fn test_header_append() { 324 | let 
mut headers = HeaderMap::new(); 325 | header_append(&mut headers, USER_AGENT, "hello").unwrap(); 326 | header_append(&mut headers, USER_AGENT, "world").unwrap(); 327 | 328 | let vals: Vec<_> = headers.get_all(USER_AGENT).into_iter().collect(); 329 | assert_eq!(vals.len(), 2); 330 | for val in vals { 331 | assert!(val == "hello" || val == "world"); 332 | } 333 | } 334 | 335 | #[test] 336 | fn test_http_url_with_http_proxy() { 337 | let mut req = PreparedRequest { 338 | method: Method::GET, 339 | url: Url::parse("http://reddit.com/r/rust").unwrap(), 340 | body: Empty, 341 | headers: HeaderMap::new(), 342 | base_settings: Arc::new(BaseSettings::default()), 343 | }; 344 | 345 | let proxy = Url::parse("http://proxy:3128").unwrap(); 346 | let mut buf: Vec = vec![]; 347 | req.write_request(&mut buf, &req.url.clone(), Some(&proxy)).unwrap(); 348 | 349 | let text = std::str::from_utf8(&buf).unwrap(); 350 | let lines: Vec<_> = text.split("\r\n").collect(); 351 | 352 | assert_eq!(lines[0], "GET http://reddit.com/r/rust HTTP/1.1"); 353 | } 354 | 355 | #[test] 356 | fn test_http_url_with_https_proxy() { 357 | let mut req = PreparedRequest { 358 | method: Method::GET, 359 | url: Url::parse("http://reddit.com/r/rust").unwrap(), 360 | body: Empty, 361 | headers: HeaderMap::new(), 362 | base_settings: Arc::new(BaseSettings::default()), 363 | }; 364 | 365 | let proxy = Url::parse("http://proxy:3128").unwrap(); 366 | let mut buf: Vec = vec![]; 367 | req.write_request(&mut buf, &req.url.clone(), Some(&proxy)).unwrap(); 368 | 369 | let text = std::str::from_utf8(&buf).unwrap(); 370 | let lines: Vec<_> = text.split("\r\n").collect(); 371 | 372 | assert_eq!(lines[0], "GET http://reddit.com/r/rust HTTP/1.1"); 373 | } 374 | } 375 | -------------------------------------------------------------------------------- /src/request/proxy.rs: -------------------------------------------------------------------------------- 1 | use std::{env, vec}; 2 | 3 | use url::Url; 4 | 5 | fn get_env(name: &str) 
-> Option { 6 | match env::var(name.to_ascii_lowercase()).or_else(|_| env::var(name.to_ascii_uppercase())) { 7 | Ok(s) => Some(s), 8 | Err(env::VarError::NotPresent) => None, 9 | Err(env::VarError::NotUnicode(_)) => { 10 | warn!( 11 | "Environment variable {} contains non-unicode characters", 12 | name.to_ascii_uppercase() 13 | ); 14 | None 15 | } 16 | } 17 | } 18 | 19 | fn get_env_url(name: &str) -> Option { 20 | match get_env(name) { 21 | Some(val) if val.trim().is_empty() => None, 22 | Some(val) => match Url::parse(&val) { 23 | Ok(url) => match url.scheme() { 24 | "http" | "https" => Some(url), 25 | _ => { 26 | warn!( 27 | "Environment variable {} contains unsupported proxy scheme: {}", 28 | name.to_ascii_uppercase(), 29 | url.scheme() 30 | ); 31 | None 32 | } 33 | }, 34 | Err(err) => { 35 | warn!( 36 | "Environment variable {} contains invalid URL: {}", 37 | name.to_ascii_uppercase(), 38 | err 39 | ); 40 | None 41 | } 42 | }, 43 | None => None, 44 | } 45 | } 46 | 47 | /// Contains proxy settings and utilities to find which proxy to use for a given URL. 48 | #[derive(Clone, Debug)] 49 | pub struct ProxySettings { 50 | http_proxy: Option, 51 | https_proxy: Option, 52 | disable_proxies: bool, 53 | no_proxy_hosts: Vec, 54 | } 55 | 56 | impl ProxySettings { 57 | /// Get a new builder for ProxySettings. 58 | pub fn builder() -> ProxySettingsBuilder { 59 | ProxySettingsBuilder::new() 60 | } 61 | 62 | /// Get the proxy configuration from the environment using the `curl`/Unix proxy conventions. 63 | /// 64 | /// Only `ALL_PROXY`, `HTTP_PROXY`, `HTTPS_PROXY` and `NO_PROXY` are supported. 65 | /// Proxies can be disabled on all requests by setting `NO_PROXY` to `*`, similar to `curl`. 66 | /// `HTTP_PROXY` or `HTTPS_PROXY` take precedence over values set by `ALL_PROXY` for their 67 | /// respective schemes. 
68 | /// 69 | /// See 70 | pub fn from_env() -> ProxySettings { 71 | let all_proxy = get_env_url("all_proxy"); 72 | let http_proxy = get_env_url("http_proxy"); 73 | let https_proxy = get_env_url("https_proxy"); 74 | let no_proxy = get_env("no_proxy"); 75 | 76 | let disable_proxies = no_proxy.as_deref().unwrap_or("") == "*"; 77 | let mut no_proxy_hosts = vec![]; 78 | 79 | if !disable_proxies { 80 | if let Some(no_proxy) = no_proxy { 81 | no_proxy_hosts.extend( 82 | no_proxy 83 | .split(',') 84 | .map(|s| s.trim().trim_start_matches('.').to_lowercase()), 85 | ); 86 | } 87 | } 88 | 89 | ProxySettings { 90 | http_proxy: http_proxy.or_else(|| all_proxy.clone()), 91 | https_proxy: https_proxy.or(all_proxy), 92 | disable_proxies, 93 | no_proxy_hosts, 94 | } 95 | } 96 | 97 | /// Get the proxy URL to use for the given URL. 98 | /// 99 | /// None is returned if there is no proxy configured for the scheme or if the hostname 100 | /// matches a pattern in the no proxy list. 101 | pub fn for_url(&self, url: &Url) -> Option<&Url> { 102 | if self.disable_proxies { 103 | return None; 104 | } 105 | 106 | if let Some(host) = url.host_str() { 107 | if !self 108 | .no_proxy_hosts 109 | .iter() 110 | .any(|x| host.ends_with(x.to_lowercase().as_str())) 111 | { 112 | return match url.scheme() { 113 | "http" => self.http_proxy.as_ref(), 114 | "https" => self.https_proxy.as_ref(), 115 | _ => None, 116 | }; 117 | } 118 | } 119 | None 120 | } 121 | } 122 | 123 | /// Utility to build ProxySettings easily. 124 | #[derive(Clone, Debug)] 125 | pub struct ProxySettingsBuilder { 126 | inner: ProxySettings, 127 | } 128 | 129 | impl ProxySettingsBuilder { 130 | /// Create a new ProxySetting builder with no initial configuration. 131 | pub fn new() -> Self { 132 | ProxySettingsBuilder { 133 | inner: ProxySettings { 134 | http_proxy: None, 135 | https_proxy: None, 136 | disable_proxies: false, 137 | no_proxy_hosts: vec![], 138 | }, 139 | } 140 | } 141 | 142 | /// Set the proxy for http requests. 
143 | pub fn http_proxy(mut self, val: V) -> Self 144 | where 145 | V: Into>, 146 | { 147 | self.inner.http_proxy = val.into(); 148 | self 149 | } 150 | 151 | /// Set the proxy for https requests. 152 | pub fn https_proxy(mut self, val: V) -> Self 153 | where 154 | V: Into>, 155 | { 156 | self.inner.https_proxy = val.into(); 157 | self 158 | } 159 | 160 | /// Add a hostname pattern to ignore when finding the proxy to use for a URL. 161 | /// 162 | /// For instance `mycompany.local` will make requests with the hostname `mycompany.local` 163 | /// not go trough the proxy. 164 | pub fn add_no_proxy_host(mut self, pattern: impl AsRef) -> Self { 165 | self.inner.no_proxy_hosts.push(pattern.as_ref().to_lowercase()); 166 | self 167 | } 168 | 169 | /// Build the settings. 170 | pub fn build(self) -> ProxySettings { 171 | self.inner 172 | } 173 | } 174 | 175 | impl Default for ProxySettingsBuilder { 176 | fn default() -> Self { 177 | ProxySettingsBuilder::new() 178 | } 179 | } 180 | 181 | #[test] 182 | fn test_proxy_for_url() { 183 | let s = ProxySettings { 184 | http_proxy: Some("http://proxy1:3128".parse().unwrap()), 185 | https_proxy: Some("http://proxy2:3128".parse().unwrap()), 186 | disable_proxies: false, 187 | no_proxy_hosts: vec!["reddit.com".into()], 188 | }; 189 | 190 | assert_eq!( 191 | s.for_url(&Url::parse("http://google.ca").unwrap()), 192 | Some(&"http://proxy1:3128".parse().unwrap()) 193 | ); 194 | 195 | assert_eq!( 196 | s.for_url(&Url::parse("https://google.ca").unwrap()), 197 | Some(&"http://proxy2:3128".parse().unwrap()) 198 | ); 199 | 200 | assert_eq!(s.for_url(&Url::parse("https://reddit.com").unwrap()), None); 201 | } 202 | 203 | #[test] 204 | fn test_proxy_for_url_disabled() { 205 | let s = ProxySettings { 206 | http_proxy: Some("http://proxy1:3128".parse().unwrap()), 207 | https_proxy: Some("http://proxy2:3128".parse().unwrap()), 208 | disable_proxies: true, 209 | no_proxy_hosts: vec![], 210 | }; 211 | 212 | 
assert_eq!(s.for_url(&Url::parse("https://reddit.com").unwrap()), None); 213 | assert_eq!(s.for_url(&Url::parse("https://www.google.ca").unwrap()), None); 214 | } 215 | 216 | #[cfg(test)] 217 | fn with_reset_proxy_vars(test: T) 218 | where 219 | T: FnOnce() + std::panic::UnwindSafe, 220 | { 221 | use std::sync::Mutex; 222 | 223 | lazy_static::lazy_static! { 224 | static ref LOCK: Mutex<()> = Mutex::new(()); 225 | }; 226 | 227 | let _guard = LOCK.lock().unwrap(); 228 | 229 | env::remove_var("ALL_PROXY"); 230 | env::remove_var("HTTP_PROXY"); 231 | env::remove_var("HTTPS_PROXY"); 232 | env::remove_var("NO_PROXY"); 233 | 234 | let result = std::panic::catch_unwind(test); 235 | 236 | // teardown if ever needed 237 | 238 | if let Err(ctx) = result { 239 | std::panic::resume_unwind(ctx); 240 | } 241 | } 242 | 243 | #[test] 244 | fn test_proxy_from_env_all_proxy() { 245 | with_reset_proxy_vars(|| { 246 | env::set_var("ALL_PROXY", "http://proxy:3128"); 247 | 248 | let s = ProxySettings::from_env(); 249 | 250 | assert_eq!(s.http_proxy.unwrap().as_str(), "http://proxy:3128/"); 251 | assert_eq!(s.https_proxy.unwrap().as_str(), "http://proxy:3128/"); 252 | }); 253 | } 254 | 255 | #[test] 256 | fn test_proxy_from_env_override() { 257 | with_reset_proxy_vars(|| { 258 | env::set_var("ALL_PROXY", "http://proxy:3128"); 259 | env::set_var("HTTP_PROXY", "http://proxy:3129"); 260 | env::set_var("HTTPS_PROXY", "http://proxy:3130"); 261 | 262 | let s = ProxySettings::from_env(); 263 | 264 | assert_eq!(s.http_proxy.unwrap().as_str(), "http://proxy:3129/"); 265 | assert_eq!(s.https_proxy.unwrap().as_str(), "http://proxy:3130/"); 266 | }); 267 | } 268 | 269 | #[test] 270 | fn test_proxy_from_env_no_proxy_wildcard() { 271 | with_reset_proxy_vars(|| { 272 | env::set_var("NO_PROXY", "*"); 273 | 274 | let s = ProxySettings::from_env(); 275 | 276 | assert!(s.disable_proxies); 277 | }); 278 | } 279 | 280 | #[test] 281 | fn test_proxy_from_env_no_proxy_root_domain() { 282 | 
with_reset_proxy_vars(|| { 283 | env::set_var("NO_PROXY", ".myroot.com"); 284 | 285 | let s = ProxySettings::from_env(); 286 | 287 | let url = Url::parse("https://mysub.myroot.com").unwrap(); 288 | assert!(s.for_url(&url).is_none()); 289 | assert_eq!(s.no_proxy_hosts[0], "myroot.com"); 290 | }); 291 | } 292 | 293 | #[test] 294 | fn test_proxy_from_env_no_proxy() { 295 | with_reset_proxy_vars(|| { 296 | env::set_var("NO_PROXY", "example.com, www.reddit.com, google.ca "); 297 | 298 | let s = ProxySettings::from_env(); 299 | 300 | assert_eq!(s.no_proxy_hosts, vec!["example.com", "www.reddit.com", "google.ca"]); 301 | }); 302 | } 303 | -------------------------------------------------------------------------------- /src/request/session.rs: -------------------------------------------------------------------------------- 1 | use std::convert::TryInto; 2 | use std::sync::Arc; 3 | use std::time::Duration; 4 | 5 | use http::header::{HeaderValue, IntoHeaderName}; 6 | use http::Method; 7 | 8 | #[cfg(feature = "charsets")] 9 | use crate::charsets::Charset; 10 | use crate::error::{Error, Result}; 11 | use crate::request::proxy::ProxySettings; 12 | use crate::request::{BaseSettings, RequestBuilder}; 13 | use crate::tls::Certificate; 14 | 15 | /// `Session` is a type that can carry settings over multiple requests. The settings applied to the 16 | /// `Session` are applied to every request created from this `Session`. 17 | /// 18 | /// `Session` can be cloned cheaply and sent to other threads as it uses [std::sync::Arc] internally. 19 | #[derive(Clone, Debug, Default)] 20 | pub struct Session { 21 | base_settings: Arc, 22 | } 23 | 24 | impl Session { 25 | /// Create a new `Session` with default settings. 26 | pub fn new() -> Session { 27 | Session { 28 | base_settings: Arc::new(BaseSettings::default()), 29 | } 30 | } 31 | 32 | /// Create a new `RequestBuilder` with the GET method and this Session's settings applied on it. 
33 | pub fn get(&self, base_url: U) -> RequestBuilder 34 | where 35 | U: AsRef, 36 | { 37 | RequestBuilder::with_settings(Method::GET, base_url, self.base_settings.clone()) 38 | } 39 | 40 | /// Create a new `RequestBuilder` with the POST method and this Session's settings applied on it. 41 | pub fn post(&self, base_url: U) -> RequestBuilder 42 | where 43 | U: AsRef, 44 | { 45 | RequestBuilder::with_settings(Method::POST, base_url, self.base_settings.clone()) 46 | } 47 | 48 | /// Create a new `RequestBuilder` with the PUT method and this Session's settings applied on it. 49 | pub fn put(&self, base_url: U) -> RequestBuilder 50 | where 51 | U: AsRef, 52 | { 53 | RequestBuilder::with_settings(Method::PUT, base_url, self.base_settings.clone()) 54 | } 55 | 56 | /// Create a new `RequestBuilder` with the DELETE method and this Session's settings applied on it. 57 | pub fn delete(&self, base_url: U) -> RequestBuilder 58 | where 59 | U: AsRef, 60 | { 61 | RequestBuilder::with_settings(Method::DELETE, base_url, self.base_settings.clone()) 62 | } 63 | 64 | /// Create a new `RequestBuilder` with the HEAD method and this Session's settings applied on it. 65 | pub fn head(&self, base_url: U) -> RequestBuilder 66 | where 67 | U: AsRef, 68 | { 69 | RequestBuilder::with_settings(Method::HEAD, base_url, self.base_settings.clone()) 70 | } 71 | 72 | /// Create a new `RequestBuilder` with the OPTIONS method and this Session's settings applied on it. 73 | pub fn options(&self, base_url: U) -> RequestBuilder 74 | where 75 | U: AsRef, 76 | { 77 | RequestBuilder::with_settings(Method::OPTIONS, base_url, self.base_settings.clone()) 78 | } 79 | 80 | /// Create a new `RequestBuilder` with the PATCH method and this Session's settings applied on it. 
81 | pub fn patch(&self, base_url: U) -> RequestBuilder 82 | where 83 | U: AsRef, 84 | { 85 | RequestBuilder::with_settings(Method::PATCH, base_url, self.base_settings.clone()) 86 | } 87 | 88 | /// Create a new `RequestBuilder` with the TRACE method and this Session's settings applied on it. 89 | pub fn trace(&self, base_url: U) -> RequestBuilder 90 | where 91 | U: AsRef, 92 | { 93 | RequestBuilder::with_settings(Method::TRACE, base_url, self.base_settings.clone()) 94 | } 95 | 96 | // 97 | // Settings 98 | // 99 | 100 | /// Modify a header for this `Session`. 101 | /// 102 | /// If the header is already present, the value will be replaced. If you wish to append a new header, 103 | /// use `header_append`. 104 | /// 105 | /// # Panics 106 | /// This method will panic if the value is invalid. 107 | pub fn header(&mut self, header: H, value: V) 108 | where 109 | H: IntoHeaderName, 110 | V: TryInto, 111 | Error: From, 112 | { 113 | self.try_header(header, value).expect("invalid header value"); 114 | } 115 | 116 | /// Append a new header for this `Session`. 117 | /// 118 | /// The new header is always appended to the headers, even if the header already exists. 119 | /// 120 | /// # Panics 121 | /// This method will panic if the value is invalid. 122 | pub fn header_append(&mut self, header: H, value: V) 123 | where 124 | H: IntoHeaderName, 125 | V: TryInto, 126 | Error: From, 127 | { 128 | self.try_header_append(header, value).expect("invalid header value"); 129 | } 130 | 131 | /// Modify a header for this `Session`. 132 | /// 133 | /// If the header is already present, the value will be replaced. If you wish to append a new header, 134 | /// use `header_append`. 135 | pub fn try_header(&mut self, header: H, value: V) -> Result<()> 136 | where 137 | H: IntoHeaderName, 138 | V: TryInto, 139 | Error: From, 140 | { 141 | self.base_settings.try_header(header, value) 142 | } 143 | 144 | /// Append a new header to this `Session`. 
145 | /// 146 | /// The new header is always appended to the headers, even if the header already exists. 147 | pub fn try_header_append(&mut self, header: H, value: V) -> Result<()> 148 | where 149 | H: IntoHeaderName, 150 | V: TryInto, 151 | Error: From, 152 | { 153 | self.base_settings.try_header_append(header, value) 154 | } 155 | 156 | /// Set the maximum number of headers accepted in responses to requests created from this `Session`. 157 | /// 158 | /// The default is 100. 159 | pub fn max_headers(&mut self, max_headers: usize) { 160 | self.base_settings.set_max_headers(max_headers); 161 | } 162 | 163 | /// Set the maximum number of redirections the requests created from this `Session` can perform. 164 | /// 165 | /// The default is 5. 166 | pub fn max_redirections(&mut self, max_redirections: u32) { 167 | self.base_settings.set_max_redirections(max_redirections); 168 | } 169 | 170 | /// Sets if requests created from this `Session` should follow redirects, 3xx codes. 171 | /// 172 | /// This value defaults to true. 173 | pub fn follow_redirects(&mut self, follow_redirects: bool) { 174 | self.base_settings.set_follow_redirects(follow_redirects); 175 | } 176 | 177 | /// Sets a connect timeout for requests created from this `Session`. 178 | /// 179 | /// The default is 30 seconds. 180 | pub fn connect_timeout(&mut self, connect_timeout: Duration) { 181 | self.base_settings.set_connect_timeout(connect_timeout); 182 | } 183 | 184 | /// Sets a read timeout for requests created from this `Session`. 185 | /// 186 | /// The default is 30 seconds. 187 | pub fn read_timeout(&mut self, read_timeout: Duration) { 188 | self.base_settings.set_read_tmeout(read_timeout); 189 | } 190 | 191 | /// Sets a timeout for the maximum duration of requests created from this `Session`. 192 | /// 193 | /// Applies after a TCP connection is established. Defaults to no timeout. 
194 | pub fn timeout(&mut self, timeout: Duration) { 195 | self.base_settings.set_timeout(Some(timeout)); 196 | } 197 | 198 | /// Sets the proxy settigns for requests created from this `Session`. 199 | /// 200 | /// If left untouched, the defaults are to use system proxy settings found in environment variables. 201 | pub fn proxy_settings(&mut self, proxy_settings: ProxySettings) { 202 | self.base_settings.set_proxy_settings(proxy_settings); 203 | } 204 | 205 | /// Set the default charset to use while parsing the responses of requests created from this `Session`. 206 | /// 207 | /// If the response does not say which charset it uses, this charset will be used to decode the requests. 208 | /// This value defaults to `None`, in which case ISO-8859-1 is used. 209 | #[cfg(feature = "charsets")] 210 | pub fn default_charset(&mut self, default_charset: Option) { 211 | self.base_settings.set_default_charset(default_charset); 212 | } 213 | 214 | /// Sets if requests created from this `Session` will announce that they accept compression. 215 | /// 216 | /// This value defaults to true. Note that this only lets the browser know that the requests support 217 | /// compression, the server might choose not to compress the content. 218 | #[cfg(feature = "flate2")] 219 | pub fn allow_compression(&mut self, allow_compression: bool) { 220 | self.base_settings.set_allow_compression(allow_compression); 221 | } 222 | 223 | /// Sets if requests created from this `Session` will accept invalid TLS certificates. 224 | /// 225 | /// Accepting invalid certificates implies that invalid hostnames are accepted 226 | /// as well. 227 | /// 228 | /// The default value is `false`. 229 | /// 230 | /// # Danger 231 | /// Use this setting with care. This will accept **any** TLS certificate valid or not. 232 | /// If you are using self signed certificates, it is much safer to add their root CA 233 | /// to the list of trusted root CAs by your system. 
pub fn danger_accept_invalid_certs(&mut self, accept_invalid_certs: bool) {
    self.base_settings.set_accept_invalid_certs(accept_invalid_certs);
}

/// Sets if requests created from this `Session` will accept an invalid hostname in a TLS certificate.
///
/// The default value is `false`.
///
/// # Danger
/// Use this setting with care. This will accept TLS certificates that do not match
/// the hostname.
pub fn danger_accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) {
    self.base_settings
        .set_accept_invalid_hostnames(accept_invalid_hostnames);
}

/// Adds a root certificate that will be trusted by requests created from this `Session`.
pub fn add_root_certificate(&mut self, cert: Certificate) {
    self.base_settings.add_root_certificate(cert);
}
}

--------------------------------------------------------------------------------
/src/request/settings.rs:
--------------------------------------------------------------------------------
use std::sync::Arc;
use std::time::Duration;

use http::header::IntoHeaderName;
use http::{HeaderMap, HeaderValue};

#[cfg(feature = "charsets")]
use crate::charsets::Charset;
use crate::error::{Error, Result};
use crate::request::proxy::ProxySettings;
use crate::skip_debug::SkipDebug;
use crate::tls::Certificate;

use super::{header_append, header_insert};

/// Settings applied to every request created from a `Session`.
///
/// Held behind an `Arc`; the setters below go through `Arc::make_mut`, so changing settings
/// that are shared with other holders clones them first instead of affecting those holders.
#[derive(Clone, Debug)]
pub struct BaseSettings {
    /// Headers applied to each request.
    pub headers: HeaderMap,
    /// Extra root certificates to trust for TLS (wrapped in `SkipDebug`, presumably to keep
    /// them out of `Debug` output).
    pub root_certificates: SkipDebug<Vec<Certificate>>,
    /// Maximum number of headers accepted in a response (default 100).
    pub max_headers: usize,
    /// Maximum number of redirections followed (default 5).
    pub max_redirections: u32,
    /// Whether 3xx redirects are followed (default true).
    pub follow_redirects: bool,
    /// Timeout for establishing a connection (default 30 seconds).
    pub connect_timeout: Duration,
    /// Timeout applied to socket reads (default 30 seconds).
    pub read_timeout: Duration,
    /// Maximum duration of a request; `None` means no limit (the default).
    pub timeout: Option<Duration>,
    /// Proxy configuration, read from the environment by default.
    pub proxy_settings: ProxySettings,
    /// Accept any TLS certificate (default false).
    pub accept_invalid_certs: bool,
    /// Accept TLS certificates whose hostname does not match (default false).
    pub accept_invalid_hostnames: bool,
    /// Charset used to decode responses that do not declare one.
    #[cfg(feature = "charsets")]
    pub default_charset: Option<Charset>,
    /// Whether requests announce that they accept compressed responses (default true).
    #[cfg(feature = "flate2")]
    pub allow_compression: bool,
}

impl Default for BaseSettings {
    fn default() -> Self {
        BaseSettings {
            headers: HeaderMap::new(),
            max_headers: 100,
            max_redirections: 5,
            follow_redirects: true,
            connect_timeout: Duration::from_secs(30),
            read_timeout: Duration::from_secs(30),
            timeout: None,
            proxy_settings: ProxySettings::from_env(),
            accept_invalid_certs: false,
            accept_invalid_hostnames: false,
            root_certificates: SkipDebug(Vec::new()),

            #[cfg(feature = "charsets")]
            default_charset: None,
            #[cfg(feature = "flate2")]
            allow_compression: true,
        }
    }
}

// Generates a `pub(crate)` setter named `$name` that overwrites field `$param` through
// `Arc::make_mut`.
macro_rules! basic_setter {
    ($name:ident, $param:ident, $type:ty) => {
        #[inline]
        pub(crate) fn $name(self: &mut Arc<Self>, $param: $type) {
            Arc::make_mut(self).$param = $param;
        }
    };
}

impl BaseSettings {
    #[inline]
    fn headers_mut(self: &mut Arc<Self>) -> &mut HeaderMap {
        &mut Arc::make_mut(self).headers
    }

    /// Set a header, replacing any previous value (via `header_insert`).
    #[inline]
    pub(crate) fn try_header<H, V>(self: &mut Arc<Self>, header: H, value: V) -> Result<()>
    where
        H: IntoHeaderName,
        V: TryInto<HeaderValue>,
        Error: From<V::Error>,
    {
        header_insert(self.headers_mut(), header, value)
    }

    /// Add a header value, keeping existing ones (via `header_append`).
    #[inline]
    pub(crate) fn try_header_append<H, V>(self: &mut Arc<Self>, header: H, value: V) -> Result<()>
    where
        H: IntoHeaderName,
        V: TryInto<HeaderValue>,
        Error: From<V::Error>,
    {
        header_append(self.headers_mut(), header, value)
    }

    #[inline]
    pub(crate) fn add_root_certificate(self: &mut Arc<Self>, cert: Certificate) {
        Arc::make_mut(self).root_certificates.0.push(cert);
    }

    basic_setter!(set_max_headers, max_headers, usize);
    basic_setter!(set_max_redirections, max_redirections, u32);
    basic_setter!(set_follow_redirects, follow_redirects, bool);
    basic_setter!(set_connect_timeout, connect_timeout, Duration);
    // The misspelled name is the one `Session::read_timeout` calls.
    basic_setter!(set_read_tmeout, read_timeout, Duration);
    basic_setter!(set_timeout, timeout, Option<Duration>);
    basic_setter!(set_proxy_settings, proxy_settings, ProxySettings);
    basic_setter!(set_accept_invalid_certs, accept_invalid_certs, bool);
    basic_setter!(set_accept_invalid_hostnames, accept_invalid_hostnames, bool);
    #[cfg(feature = "charsets")]
    basic_setter!(set_default_charset, default_charset, Option<Charset>);
    #[cfg(feature = "flate2")]
    basic_setter!(set_allow_compression, allow_compression, bool);
}

--------------------------------------------------------------------------------
/src/streams.rs:
--------------------------------------------------------------------------------
#[cfg(test)]
use std::io::Cursor;
use std::io::{self, Read, Write};
#[cfg(not(windows))]
use std::net::Shutdown;
use std::net::TcpStream;
#[cfg(windows)]
use std::os::{
    raw::c_int,
    windows::{io::AsRawSocket, raw::SOCKET},
};
use std::sync::mpsc;
use std::thread;
use std::time::Instant;

use base64::Engine;

use url::{Host, Url};

use crate::happy;
use crate::parsing::buffers::BufReaderWrite;
use crate::parsing::response::parse_response_head;
use crate::request::BaseSettings;
use crate::tls::{TlsHandshaker, TlsStream};
use crate::{ErrorKind, Result};

/// Everything `BaseStream::connect` needs to open a connection.
pub struct ConnectInfo<'a> {
    /// URL being requested.
    pub url: &'a Url,
    /// Proxy to go through, if any.
    pub proxy: Option<&'a Url>,
    pub base_settings: &'a BaseSettings,
    /// Point in time after which the connection is torn down by a watchdog thread.
    pub deadline: Option<Instant>,
}

#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum BaseStream {
    Plain {
        stream: TcpStream,
        timeout: Option<mpsc::Sender<()>>,
    },
    Tls {
        stream: TlsStream<TcpStream>,
        timeout: Option<mpsc::Sender<()>>,
    },
    Tunnel {
        stream: Box<TlsStream<BufReaderWrite<BaseStream>>>,
    },
    #[cfg(test)]
    Mock(Cursor<Vec<u8>>),
}

impl BaseStream {
    /// Open a connection for `info.url`.
    ///
    /// The TCP (or TLS, for `https` proxy/URL schemes) connection is made to `info.proxy` when
    /// set, otherwise to the URL itself. For an `https` URL behind a proxy, a `CONNECT` tunnel
    /// is then established and TLS is negotiated with the remote host through it.
    pub fn connect(info: &ConnectInfo) -> Result<BaseStream> {
        let connect_url = info.proxy.unwrap_or(info.url);

        let host = connect_url.host().ok_or(ErrorKind::InvalidUrlHost)?;
        let port = connect_url.port_or_known_default().ok_or(ErrorKind::InvalidUrlPort)?;

        debug!("trying to connect to {}:{}", host, port);

        let stream = match connect_url.scheme() {
            "http" => BaseStream::connect_tcp(&host, port, info)
                .map(|(stream, timeout)| BaseStream::Plain { stream, timeout }),
            "https" => BaseStream::connect_tls(&host, port, info),
            _ => Err(ErrorKind::InvalidBaseUrl.into()),
        }?;

        if let Some(proxy_url) = info.proxy {
            if info.url.scheme() == "https" {
                return BaseStream::initiate_tunnel(stream, proxy_url, info.url, info.base_settings);
            }
        }

        Ok(stream)
    }

    /// Ask the proxy at the other end of `stream` to open a `CONNECT` tunnel to `remote_url`,
    /// then perform the TLS handshake with the remote host over that tunnel.
    ///
    /// A non-2xx answer from the proxy becomes `ErrorKind::ConnectError`, carrying the status
    /// code and up to 10 KiB of the response body.
    fn initiate_tunnel(
        mut stream: BaseStream,
        proxy_url: &Url,
        remote_url: &Url,
        base_settings: &BaseSettings,
    ) -> Result<BaseStream> {
        let remote_host = remote_url.host_str().ok_or(ErrorKind::InvalidUrlHost)?;
        let remote = remote_url.host().ok_or(ErrorKind::InvalidUrlHost)?;
        let remote_port = remote_url.port_or_known_default().ok_or(ErrorKind::InvalidUrlPort)?;
        let proxy_host = proxy_url.host_str().ok_or(ErrorKind::InvalidUrlHost)?;
        let proxy_port = proxy_url.port_or_known_default().ok_or(ErrorKind::InvalidUrlPort)?;

        debug!(
            "tunnelling to {}:{} via {}:{}",
            remote_host, remote_port, proxy_host, proxy_port,
        );

        write!(stream, "CONNECT {remote_host}:{remote_port} HTTP/1.1\r\n")?;
        write!(stream, "Host: {proxy_host}:{proxy_port}\r\n")?;
        write!(stream, "Connection: close\r\n")?;

        // `Url::has_authority()` is true for every http(s) URL, so check for actual credentials;
        // otherwise an empty `Proxy-Authorization: Basic Og==` is sent to every proxy.
        let username = proxy_url.username();
        let password = proxy_url.password();
        if !username.is_empty() || password.is_some() {
            // `Url` keeps userinfo percent-encoded; the proxy expects the raw credentials.
            let mut auth = percent_decode(username);
            auth.push(b':');
            auth.extend(percent_decode(password.unwrap_or("")));
            let basic_auth = base64::engine::general_purpose::STANDARD.encode(auth);
            write!(stream, "Proxy-Authorization: Basic {basic_auth}\r\n")?;
        }
        write!(stream, "\r\n")?;

        let mut stream = BufReaderWrite::new(stream);
        let (status, _) = parse_response_head(&mut stream, base_settings.max_headers)?;

        debug!("tunnel response status code is {}", status);

        if !status.is_success() {
            // Error initializing tunnel, get status code and up to 10 KiB of data from the body.
            let mut buf = Vec::with_capacity(2048);
            stream.take(10 * 1024).read_to_end(&mut buf)?;
            let err = ErrorKind::ConnectError {
                status_code: status,
                body: buf,
            };
            return Err(err.into());
        }

        let mut handshaker = TlsHandshaker::new();
        apply_base_settings(&mut handshaker, base_settings);
        let stream = handshaker.handshake(&tls_domain(&remote), stream)?;

        Ok(BaseStream::Tunnel {
            stream: Box::new(stream),
        })
    }

    /// Open a TCP connection to `host:port` and apply the configured read timeout.
    ///
    /// When `info.deadline` is set, a watchdog thread is spawned that tears the socket down
    /// once the deadline passes, unless a message is sent on the returned channel (or its
    /// sender is dropped) first. On firing it drops its receiver, which `read_timeout` uses to
    /// tell a deadline expiry apart from an ordinary end of stream.
    fn connect_tcp(
        host: &Host<&str>,
        port: u16,
        info: &ConnectInfo,
    ) -> Result<(TcpStream, Option<mpsc::Sender<()>>)> {
        let stream = happy::connect(host, port, info.base_settings.connect_timeout, info.deadline)?;
        stream.set_read_timeout(Some(info.base_settings.read_timeout))?;
        let timeout = info
            .deadline
            .map(|deadline| -> Result<mpsc::Sender<()>> {
                #[cfg(not(windows))]
                let stream = stream.try_clone()?;
                #[cfg(windows)]
                let socket = stream.as_raw_socket();

                let (tx, rx) = mpsc::channel();
                thread::spawn(move || {
                    let shutdown = match deadline.checked_duration_since(Instant::now()) {
                        Some(timeout) => rx.recv_timeout(timeout) == Err(mpsc::RecvTimeoutError::Timeout),
                        None => rx.try_recv() == Err(mpsc::TryRecvError::Empty),
                    };

                    if shutdown {
                        // Dropping the receiver makes later `send`s fail, signalling the timeout.
                        drop(rx);

                        #[cfg(not(windows))]
                        let _ = stream.shutdown(Shutdown::Both);

                        #[cfg(windows)]
                        extern "system" {
                            fn closesocket(socket: SOCKET) -> c_int;
                        }

                        // NOTE(review): the socket is closed here while the `TcpStream` owning it
                        // is still alive, so it is closed again when that stream is dropped; if
                        // Windows reuses the handle in between, an unrelated socket could be
                        // closed. Worth confirming.
                        #[cfg(windows)]
                        unsafe {
                            closesocket(socket);
                        }
                    }
                });
                Ok(tx)
            })
            .transpose()?;
        Ok((stream, timeout))
    }

    /// Open a TCP connection to `host:port` and perform a TLS handshake with `host` over it.
    fn connect_tls(host: &Host<&str>, port: u16, info: &ConnectInfo) -> Result<BaseStream> {
        let (stream, timeout) = BaseStream::connect_tcp(host, port, info)?;
        let mut handshaker = TlsHandshaker::new();
        apply_base_settings(&mut handshaker, info.base_settings);
        let stream = handshaker.handshake(&tls_domain(host), stream)?;
        Ok(BaseStream::Tls { stream, timeout })
    }

    #[cfg(test)]
    pub fn mock(bytes: Vec<u8>) -> BaseStream {
        BaseStream::Mock(Cursor::new(bytes))
    }
}

impl Read for BaseStream {
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self {
            BaseStream::Plain { stream, timeout } => read_timeout(stream, buf, timeout),
            BaseStream::Tls { stream, timeout } => read_timeout(stream, buf, timeout),
            BaseStream::Tunnel { stream } => stream.read(buf),
            #[cfg(test)]
            BaseStream::Mock(s) => s.read(buf),
        }
    }
}

impl Write for BaseStream {
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self {
            BaseStream::Plain { stream, .. } => stream.write(buf),
            BaseStream::Tls { stream, .. } => stream.write(buf),
            BaseStream::Tunnel { stream } => stream.write(buf),
            #[cfg(test)]
            _ => Ok(0),
        }
    }

    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        match self {
            BaseStream::Plain { stream, .. } => stream.flush(),
            BaseStream::Tls { stream, .. } => stream.flush(),
            BaseStream::Tunnel { stream } => stream.flush(),
            #[cfg(test)]
            _ => Ok(()),
        }
    }
}

/// Read from `stream`, turning a read cut short by the deadline watchdog of
/// `BaseStream::connect_tcp` into `io::ErrorKind::TimedOut`.
///
/// The watchdog drops its receiver when it fires, so a failed `send` on `timeout` after a
/// suspicious read result means the deadline expired.
fn read_timeout(stream: &mut impl Read, buf: &mut [u8], timeout: &Option<mpsc::Sender<()>>) -> io::Result<usize> {
    match stream.read(buf) {
        Ok(0) => {
            // Must cover the same targets as the `#[cfg(not(windows))]` shutdown in `connect_tcp`;
            // on Unix a connection shut down by the timeout thread yields a 0-byte read.
            #[cfg(not(windows))]
            if let Some(timeout) = timeout {
                if !buf.is_empty() && timeout.send(()).is_err() {
                    return Err(io::ErrorKind::TimedOut.into());
                }
            }
            Ok(0)
        }
        Ok(read) => Ok(read),
        Err(err) => {
            #[cfg(windows)]
            if let Some(timeout) = timeout {
                // On Windows we get a ConnectionAborted when the connection is shutdown by the timeout thread.
                if err.kind() == io::ErrorKind::ConnectionAborted && timeout.send(()).is_err() {
                    return Err(io::ErrorKind::TimedOut.into());
                }
            }
            Err(err)
        }
    }
}

/// Copy the TLS-related settings (certificate and hostname checks, extra root certificates)
/// onto `handshaker`.
fn apply_base_settings(handshaker: &mut TlsHandshaker, base_settings: &BaseSettings) {
    handshaker.danger_accept_invalid_certs(base_settings.accept_invalid_certs);
    handshaker.danger_accept_invalid_hostnames(base_settings.accept_invalid_hostnames);
    for cert in &base_settings.root_certificates.0 {
        handshaker.add_root_certificate(cert.clone());
    }
}

/// Name handed to the TLS backend for SNI and certificate verification.
///
/// `Host`'s `Display` wraps IPv6 addresses in brackets (`[::1]`), which is URL syntax but not a
/// valid server name, so IP addresses are formatted without them.
fn tls_domain(host: &Host<&str>) -> String {
    match host {
        Host::Domain(domain) => (*domain).to_owned(),
        Host::Ipv4(addr) => addr.to_string(),
        Host::Ipv6(addr) => addr.to_string(),
    }
}

/// Decode `%XX` escapes in a URL component such as `Url::username()` / `Url::password()`.
///
/// Malformed escapes are kept verbatim.
fn percent_decode(input: &str) -> Vec<u8> {
    let bytes = input.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' && i + 2 < bytes.len() {
            let hex = |b: u8| (b as char).to_digit(16);
            if let (Some(hi), Some(lo)) = (hex(bytes[i + 1]), hex(bytes[i + 2])) {
                out.push((hi * 16 + lo) as u8);
                i += 3;
                continue;
            }
        }
        out.push(bytes[i]);
        i += 1;
    }
    out
}

#[test]
fn test_percent_decode() {
    assert_eq!(percent_decode("us%40er"), b"us@er");
    assert_eq!(percent_decode("p%3As%2"), b"p:s%2");
    assert_eq!(percent_decode("%zz"), b"%zz");
    assert_eq!(percent_decode(""), b"");
}

#[test]
fn test_tls_domain() {
    assert_eq!(tls_domain(&Host::Domain("example.com")), "example.com");
    assert_eq!(tls_domain(&Host::Ipv4(std::net::Ipv4Addr::new(127, 0, 0, 1))), "127.0.0.1");
    assert_eq!(tls_domain(&Host::Ipv6(std::net::Ipv6Addr::LOCALHOST)), "::1");
}

--------------------------------------------------------------------------------
/src/tls/mod.rs:
--------------------------------------------------------------------------------
#[cfg(feature = "tls-native")]
mod native_tls_impl;

#[cfg(all(feature = "__rustls", not(feature = "tls-native")))]
mod rustls_impl;

#[cfg(all(not(feature = "tls-native"), not(feature = "__rustls")))]
mod no_tls_impl;

#[cfg(feature = "tls-native")]
pub use native_tls_impl::*;

#[cfg(all(feature = "__rustls", not(feature = "tls-native")))]
pub use rustls_impl::*;

#[cfg(all(not(feature = "tls-native"), not(feature = "__rustls")))]
pub use no_tls_impl::*;

--------------------------------------------------------------------------------
/src/tls/native_tls_impl.rs:
--------------------------------------------------------------------------------
use std::fmt;
use std::io;
use std::io::prelude::*;

use native_tls::HandshakeError;

use crate::Result;

pub type Certificate = native_tls::Certificate;

/// TLS connector configuration backed by `native-tls`.
pub struct TlsHandshaker {
    inner: native_tls::TlsConnectorBuilder,
}

impl TlsHandshaker {
    pub fn new() -> TlsHandshaker {
        TlsHandshaker {
            inner: native_tls::TlsConnector::builder(),
        }
    }

    pub fn danger_accept_invalid_certs(&mut self, accept_invalid_certs: bool) {
        self.inner.danger_accept_invalid_certs(accept_invalid_certs);
    }

    pub fn danger_accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) {
        self.inner.danger_accept_invalid_hostnames(accept_invalid_hostnames);
    }

    pub fn add_root_certificate(&mut self, cert: Certificate) {
        self.inner.add_root_certificate(cert);
    }

    /// Run a TLS client handshake for `domain` over `stream`, retrying while the backend
    /// reports `WouldBlock`.
    pub fn handshake<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>>
    where
        S: Read + Write,
    {
        let connector = self.inner.build()?;
        let stream = match connector.connect(domain, stream) {
            Ok(stream) => stream,
            Err(HandshakeError::Failure(err)) => return Err(err.into()),
            Err(HandshakeError::WouldBlock(mut stream)) => loop {
                match stream.handshake() {
                    Ok(stream) => break stream,
                    Err(HandshakeError::Failure(err)) => return Err(err.into()),
                    Err(HandshakeError::WouldBlock(mid_stream)) => stream = mid_stream,
                }
            },
        };
        Ok(TlsStream { inner: stream })
    }
}

/// Encrypted stream produced by `TlsHandshaker::handshake`.
pub struct TlsStream<S>
where
    S: Read + Write,
{
    inner: native_tls::TlsStream<S>,
}

impl<S> Read for TlsStream<S>
where
    S: Read + Write,
{
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.inner.read(buf)
    }
}

impl<S> Write for TlsStream<S>
where
    S: Read + Write,
{
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.inner.write(buf)
    }

    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}

impl<S> fmt::Debug for TlsStream<S>
where
    S: Read + Write,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "TlsStream[native_tls]")
    }
}

--------------------------------------------------------------------------------
/src/tls/no_tls_impl.rs:
--------------------------------------------------------------------------------
use std::fmt;
use std::io;
use std::io::prelude::*;
use std::marker::PhantomData;

use crate::{ErrorKind, Result};

pub type Certificate = ();

/// Stand-in used when no TLS backend is enabled; every handshake fails with
/// `ErrorKind::TlsDisabled`.
pub struct TlsHandshaker {}

impl TlsHandshaker {
    pub fn new() -> TlsHandshaker {
        TlsHandshaker {}
    }

    pub fn danger_accept_invalid_certs(&mut self, _accept_invalid_certs: bool) {}

    pub fn danger_accept_invalid_hostnames(&mut self, _accept_invalid_hostnames: bool) {}

    pub fn add_root_certificate(&mut self, _cert: Certificate) {}

    pub fn handshake<S>(&self, _domain: &str, _stream: S) -> Result<TlsStream<S>>
    where
        S: Read + Write,
    {
        Err(ErrorKind::TlsDisabled.into())
    }
}

/// Placeholder stream type; it is never constructed since `handshake` always fails.
pub struct TlsStream<S>
where
    S: Read + Write,
{
    dummy: PhantomData<S>,
}

impl<S> Read for TlsStream<S>
where
    S: Read + Write,
{
    #[inline]
    fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
        Ok(0)
    }
}

impl<S> Write for TlsStream<S>
where
    S: Read + Write,
{
    #[inline]
    fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
        Ok(0)
    }

    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

impl<S> fmt::Debug for TlsStream<S>
where
    S: Read + Write,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "TlsStream[no_tls]")
    }
}

--------------------------------------------------------------------------------
/src/tls/rustls_impl.rs:
--------------------------------------------------------------------------------
use std::convert::TryFrom;
use std::fmt;
use std::io;
use std::io::prelude::*;
use std::sync::Arc;

use rustls::{
    client::{
        danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
        WebPkiServerVerifier,
    },
    pki_types::{CertificateDer, ServerName, UnixTime},
    ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, SignatureScheme, StreamOwned,
};
#[cfg(feature = "tls-rustls-native-roots")]
use rustls_native_certs::load_native_certs;
#[cfg(feature = "tls-rustls-webpki-roots")]
use webpki_roots::TLS_SERVER_ROOTS;

use crate::{Error, ErrorKind, Result};

pub type Certificate = CertificateDer<'static>;

/// TLS connector configuration backed by rustls.
///
/// The built `ClientConfig` is cached in `inner` and discarded whenever a setting changes.
pub struct TlsHandshaker {
    inner: Option<Arc<ClientConfig>>,
    accept_invalid_certs: bool,
    accept_invalid_hostnames: bool,
    additional_certs: Vec<Certificate>,
}

impl TlsHandshaker {
    pub fn new() -> TlsHandshaker {
        TlsHandshaker {
            inner: None,
            accept_invalid_hostnames: false,
            accept_invalid_certs: false,
            additional_certs: Vec::new(),
        }
    }

    pub fn danger_accept_invalid_certs(&mut self, accept_invalid_certs: bool) {
        self.accept_invalid_certs = accept_invalid_certs;
        self.inner = None;
    }

    pub fn danger_accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) {
        self.accept_invalid_hostnames = accept_invalid_hostnames;
        self.inner = None;
    }

    pub fn add_root_certificate(&mut self, cert: Certificate) {
        self.additional_certs.push(cert);
        self.inner = None;
    }

    /// Return the cached client configuration, building it first if needed from the enabled
    /// root stores, the additional certificates and the `danger_*` flags.
    fn client_config(&mut self) -> Result<Arc<ClientConfig>> {
        match &self.inner {
            Some(inner) => Ok(Arc::clone(inner)),
            None => {
                let mut root_store = RootCertStore::empty();

                #[cfg(feature = "tls-rustls-webpki-roots")]
                root_store.extend(TLS_SERVER_ROOTS.iter().cloned());

                #[cfg(feature = "tls-rustls-native-roots")]
                for cert in load_native_certs().certs {
                    // Inspired by https://github.com/seanmonstar/reqwest/blob/231b18f83572836c674404b33cb1ca8b35ca3e36/src/async_impl/client.rs#L363-L365
                    // Native certificate stores often include certificates with invalid formats,
                    // but we don't want those invalid entries to invalidate the entire process of
                    // loading native root certificates
                    if let Err(e) = root_store.add(cert) {
                        warn!("Could not load native root certificate: {}", e);
                    }
                }

                for cert in self.additional_certs.iter().cloned() {
                    root_store.add(cert)?;
                }

                let config = ClientConfig::builder()
                    .dangerous()
                    .with_custom_certificate_verifier(Arc::new(CustomCertVerifier {
                        upstream: WebPkiServerVerifier::builder(root_store.into()).build()?,
                        accept_invalid_certs: self.accept_invalid_certs,
                        accept_invalid_hostnames: self.accept_invalid_hostnames,
                    }))
                    .with_no_client_auth()
                    .into();

                self.inner = Some(Arc::clone(&config));

                Ok(config)
            }
        }
    }

    /// Run a TLS client handshake for `domain` over `stream`.
    ///
    /// `domain` must parse as a rustls `ServerName` (a DNS name or a bare IP address), otherwise
    /// `ErrorKind::InvalidDNSName` is returned.
    pub fn handshake<S>(&mut self, domain: &str, mut stream: S) -> Result<TlsStream<S>>
    where
        S: Read + Write,
    {
        let domain = ServerName::try_from(domain)
            .map_err(|_| Error(Box::new(ErrorKind::InvalidDNSName(domain.to_owned()))))?
            .to_owned();
        let config = self.client_config()?;
        let mut session = ClientConnection::new(config, domain)?;

        while let Err(err) = session.complete_io(&mut stream) {
            if err.kind() != io::ErrorKind::WouldBlock || !session.is_handshaking() {
                return Err(err.into());
            }
        }

        Ok(TlsStream {
            inner: StreamOwned::new(session, stream),
        })
    }
}

/// Encrypted stream produced by `TlsHandshaker::handshake`.
pub struct TlsStream<S>
where
    S: Read + Write,
{
    inner: StreamOwned<ClientConnection, S>,
}

impl<S> TlsStream<S>
where
    S: Read + Write,
{
    fn handle_close_notify(&mut self, res: io::Result<usize>) -> io::Result<usize> {
        match res {
            Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => {
                self.inner.conn.send_close_notify();
                self.inner.conn.complete_io(&mut self.inner.sock)?;

                Ok(0)
            }
            Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => {
                // In some cases the server does not terminate the connection cleanly
                // We just turn that error into EOF.
141 | Ok(0) 142 | } 143 | res => res, 144 | } 145 | } 146 | } 147 | 148 | impl Read for TlsStream 149 | where 150 | S: Read + Write, 151 | { 152 | #[inline] 153 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 154 | let res = self.inner.read(buf); 155 | self.handle_close_notify(res) 156 | } 157 | } 158 | 159 | impl Write for TlsStream 160 | where 161 | S: Read + Write, 162 | { 163 | #[inline] 164 | fn write(&mut self, buf: &[u8]) -> io::Result { 165 | self.inner.write(buf) 166 | } 167 | 168 | #[inline] 169 | fn flush(&mut self) -> io::Result<()> { 170 | self.inner.flush() 171 | } 172 | } 173 | 174 | impl fmt::Debug for TlsStream 175 | where 176 | S: Read + Write, 177 | { 178 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 179 | write!(f, "TlsStream[rustls]") 180 | } 181 | } 182 | 183 | struct CustomCertVerifier { 184 | upstream: Arc, 185 | accept_invalid_certs: bool, 186 | accept_invalid_hostnames: bool, 187 | } 188 | 189 | impl fmt::Debug for CustomCertVerifier { 190 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 191 | f.debug_struct("CustomCertVerifier").finish() 192 | } 193 | } 194 | 195 | impl ServerCertVerifier for CustomCertVerifier { 196 | fn verify_server_cert( 197 | &self, 198 | end_entity: &CertificateDer, 199 | intermediates: &[CertificateDer], 200 | server_name: &ServerName, 201 | ocsp_response: &[u8], 202 | now: UnixTime, 203 | ) -> std::result::Result { 204 | match self 205 | .upstream 206 | .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now) 207 | { 208 | Err(rustls::Error::NoCertificatesPresented | rustls::Error::InvalidCertificate(_)) 209 | if self.accept_invalid_certs => 210 | { 211 | Ok(ServerCertVerified::assertion()) 212 | } 213 | 214 | Err(rustls::Error::InvalidCertificate(rustls::CertificateError::NotValidForName)) 215 | if self.accept_invalid_hostnames => 216 | { 217 | Ok(ServerCertVerified::assertion()) 218 | } 219 | 220 | upstream => upstream, 221 | } 222 | } 223 | 224 | fn 
verify_tls12_signature( 225 | &self, 226 | message: &[u8], 227 | cert: &CertificateDer<'_>, 228 | dss: &DigitallySignedStruct, 229 | ) -> std::result::Result { 230 | self.upstream.verify_tls12_signature(message, cert, dss) 231 | } 232 | 233 | fn verify_tls13_signature( 234 | &self, 235 | message: &[u8], 236 | cert: &CertificateDer<'_>, 237 | dss: &DigitallySignedStruct, 238 | ) -> std::result::Result { 239 | self.upstream.verify_tls13_signature(message, cert, dss) 240 | } 241 | 242 | fn supported_verify_schemes(&self) -> Vec { 243 | self.upstream.supported_verify_schemes() 244 | } 245 | } 246 | -------------------------------------------------------------------------------- /tests/test_invalid_certs.rs: -------------------------------------------------------------------------------- 1 | #[test] 2 | fn test_error_when_self_signed() { 3 | let res = attohttpc::get("https://self-signed.badssl.com/").send(); 4 | let err = res.err().unwrap(); 5 | println!("{err:?}"); 6 | match err.kind() { 7 | attohttpc::ErrorKind::Tls(_) => (), 8 | _ => panic!("wrong error returned!"), 9 | } 10 | } 11 | 12 | #[test] 13 | fn test_accept_invalid_certs_ok_when_self_signed() { 14 | let res = attohttpc::get("https://self-signed.badssl.com/") 15 | .danger_accept_invalid_certs(true) 16 | .send(); 17 | assert!(res.is_ok()); 18 | } 19 | 20 | #[test] 21 | fn test_accept_invalid_certs_ok_when_wrong_host() { 22 | let res = attohttpc::get("https://wrong-host.badssl.com/") 23 | .danger_accept_invalid_certs(true) 24 | .send(); 25 | assert!(res.is_ok()); 26 | } 27 | 28 | #[test] 29 | fn test_error_when_wrong_host() { 30 | let res = attohttpc::get("https://wrong.host.badssl.com/").send(); 31 | let err = res.err().unwrap(); 32 | match err.kind() { 33 | attohttpc::ErrorKind::Tls(_) => (), 34 | _ => panic!("wrong error returned!"), 35 | } 36 | } 37 | 38 | #[test] 39 | fn test_accept_invalid_hostnames_error_when_expired() { 40 | let res = attohttpc::get("https://expired.badssl.com/") 41 | 
.danger_accept_invalid_hostnames(true) 42 | .send(); 43 | let err = res.err().unwrap(); 44 | match err.kind() { 45 | attohttpc::ErrorKind::Tls(_) => (), 46 | _ => panic!("wrong error returned!"), 47 | } 48 | } 49 | 50 | #[test] 51 | fn test_accept_invalid_hostnames_ok_when_wrong_host() { 52 | let res = attohttpc::get("https://wrong.host.badssl.com/") 53 | .danger_accept_invalid_hostnames(true) 54 | .send(); 55 | assert!(res.is_ok()); 56 | } 57 | -------------------------------------------------------------------------------- /tests/test_multipart.rs: -------------------------------------------------------------------------------- 1 | use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; 2 | use std::time::Duration; 3 | 4 | use axum::extract::{DefaultBodyLimit, Multipart, State}; 5 | use axum::routing::post; 6 | use axum::Router; 7 | use bytes::Bytes; 8 | 9 | #[derive(Debug, PartialEq, Eq)] 10 | struct Part { 11 | name: Option, 12 | file_name: Option, 13 | content_type: Option, 14 | data: Bytes, 15 | } 16 | 17 | async fn start_server() -> (u16, Receiver>) { 18 | let (send, recv) = sync_channel(1); 19 | 20 | async fn accept_form(State(send): State>>, mut multipart: Multipart) -> &'static str { 21 | let mut parts = Vec::new(); 22 | while let Some(field) = multipart.next_field().await.unwrap() { 23 | parts.push(Part { 24 | name: field.name().map(|s| s.to_string()), 25 | file_name: field.file_name().map(|s| s.to_string()), 26 | content_type: field.content_type().map(|s| s.to_string()), 27 | data: field.bytes().await.unwrap(), 28 | }); 29 | } 30 | send.send(parts).unwrap(); 31 | "OK" 32 | } 33 | 34 | let app = Router::new() 35 | .route("/multipart", post(accept_form)) 36 | .layer(DefaultBodyLimit::disable()) 37 | .with_state(send); 38 | 39 | let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); 40 | let port = listener.local_addr().unwrap().port(); 41 | tokio::spawn(async move { 42 | axum::serve(listener, app).await.unwrap(); 43 | }); 44 | 
(port, recv) 45 | } 46 | 47 | #[tokio::test(flavor = "multi_thread")] 48 | async fn test_multipart_default() -> attohttpc::Result<()> { 49 | let file = attohttpc::MultipartFile::new("file", b"abc123") 50 | .with_type("text/plain")? 51 | .with_filename("hello.txt"); 52 | let form = attohttpc::MultipartBuilder::new() 53 | .with_text("Hello", "world!") 54 | .with_file(file) 55 | .build()?; 56 | 57 | let (port, recv) = start_server().await; 58 | 59 | attohttpc::post(format!("http://localhost:{port}/multipart")) 60 | .body(form) 61 | .send()? 62 | .text()?; 63 | 64 | let parts = recv.recv_timeout(Duration::from_secs(5)).unwrap(); 65 | assert_eq!(parts.len(), 2); 66 | assert_eq!( 67 | parts, 68 | vec![ 69 | Part { 70 | name: Some("Hello".to_string()), 71 | file_name: None, 72 | content_type: None, 73 | data: Bytes::from(&b"world!"[..]) 74 | }, 75 | Part { 76 | name: Some("file".to_string()), 77 | file_name: Some("hello.txt".to_string()), 78 | content_type: Some("text/plain".to_string()), 79 | data: Bytes::from(&b"abc123"[..]) 80 | } 81 | ] 82 | ); 83 | 84 | Ok(()) 85 | } 86 | -------------------------------------------------------------------------------- /tests/test_proxy.rs: -------------------------------------------------------------------------------- 1 | mod tools; 2 | 3 | use url::Url; 4 | 5 | #[tokio::test(flavor = "multi_thread")] 6 | async fn test_http_url_with_http_proxy() -> Result<(), anyhow::Error> { 7 | let remote_port = tools::start_hello_world_server(false).await?; 8 | let remote_url = format!("http://localhost:{remote_port}"); 9 | 10 | let proxy_port = tools::start_proxy_server(false).await?; 11 | let proxy_url = Url::parse(&format!("http://localhost:{proxy_port}")).unwrap(); 12 | 13 | let settings = attohttpc::ProxySettingsBuilder::new() 14 | .http_proxy(proxy_url.clone()) 15 | .https_proxy(proxy_url) 16 | .build(); 17 | 18 | let mut sess = attohttpc::Session::new(); 19 | sess.proxy_settings(settings); 20 | 21 | let resp = 
sess.get(remote_url).danger_accept_invalid_certs(true).send().unwrap(); 22 | 23 | assert_eq!(resp.text().unwrap(), "hello"); 24 | 25 | Ok(()) 26 | } 27 | 28 | #[cfg(any(feature = "tls-native", feature = "__rustls"))] 29 | #[tokio::test(flavor = "multi_thread")] 30 | async fn test_http_url_with_https_proxy() -> Result<(), anyhow::Error> { 31 | let remote_port = tools::start_hello_world_server(false).await?; 32 | let remote_url = format!("http://localhost:{remote_port}"); 33 | 34 | let proxy_port = tools::start_proxy_server(true).await?; 35 | let proxy_url = Url::parse(&format!("https://localhost:{proxy_port}")).unwrap(); 36 | 37 | let settings = attohttpc::ProxySettingsBuilder::new() 38 | .http_proxy(proxy_url.clone()) 39 | .https_proxy(proxy_url) 40 | .build(); 41 | 42 | let mut sess = attohttpc::Session::new(); 43 | sess.proxy_settings(settings); 44 | 45 | let resp = sess.get(remote_url).danger_accept_invalid_certs(true).send().unwrap(); 46 | 47 | assert_eq!(resp.text().unwrap(), "hello"); 48 | 49 | Ok(()) 50 | } 51 | 52 | #[cfg(any(feature = "tls-native", feature = "__rustls"))] 53 | #[tokio::test(flavor = "multi_thread")] 54 | async fn test_https_url_with_http_proxy() -> Result<(), anyhow::Error> { 55 | let remote_port = tools::start_hello_world_server(true).await?; 56 | let remote_url = format!("https://localhost:{remote_port}"); 57 | 58 | let proxy_port = tools::start_proxy_server(false).await?; 59 | let proxy_url = Url::parse(&format!("http://localhost:{proxy_port}")).unwrap(); 60 | 61 | let settings = attohttpc::ProxySettingsBuilder::new() 62 | .http_proxy(proxy_url.clone()) 63 | .https_proxy(proxy_url) 64 | .build(); 65 | 66 | let mut sess = attohttpc::Session::new(); 67 | sess.proxy_settings(settings); 68 | 69 | let resp = sess.get(remote_url).danger_accept_invalid_certs(true).send().unwrap(); 70 | 71 | assert_eq!(resp.text().unwrap(), "hello"); 72 | 73 | Ok(()) 74 | } 75 | 76 | #[cfg(any(feature = "tls-native", feature = "__rustls"))] 77 | 
#[tokio::test(flavor = "multi_thread")] 78 | async fn test_https_url_with_https_proxy() -> Result<(), anyhow::Error> { 79 | let remote_port = tools::start_hello_world_server(true).await?; 80 | let remote_url = format!("https://localhost:{remote_port}"); 81 | 82 | let proxy_port = tools::start_proxy_server(true).await?; 83 | let proxy_url = Url::parse(&format!("https://localhost:{proxy_port}")).unwrap(); 84 | 85 | let settings = attohttpc::ProxySettingsBuilder::new() 86 | .http_proxy(proxy_url.clone()) 87 | .https_proxy(proxy_url) 88 | .build(); 89 | 90 | let mut sess = attohttpc::Session::new(); 91 | sess.proxy_settings(settings); 92 | 93 | let resp = sess.get(remote_url).danger_accept_invalid_certs(true).send().unwrap(); 94 | 95 | assert_eq!(resp.status().as_u16(), 200); 96 | assert_eq!(resp.text().unwrap(), "hello"); 97 | 98 | Ok(()) 99 | } 100 | 101 | #[tokio::test(flavor = "multi_thread")] 102 | async fn test_http_url_with_http_proxy_refusal() -> Result<(), anyhow::Error> { 103 | let proxy_port = tools::start_refusing_proxy_server(false).await?; 104 | let proxy_url = Url::parse(&format!("http://localhost:{proxy_port}")).unwrap(); 105 | 106 | let settings = attohttpc::ProxySettingsBuilder::new() 107 | .http_proxy(proxy_url.clone()) 108 | .https_proxy(proxy_url) 109 | .build(); 110 | 111 | let mut sess = attohttpc::Session::new(); 112 | sess.proxy_settings(settings); 113 | 114 | let resp = sess 115 | .get("http://localhost") 116 | .danger_accept_invalid_certs(true) 117 | .send() 118 | .unwrap(); 119 | 120 | assert_eq!(resp.status().as_u16(), 400); 121 | assert_eq!(resp.text().unwrap(), "bad request"); 122 | 123 | Ok(()) 124 | } 125 | 126 | #[tokio::test(flavor = "multi_thread")] 127 | async fn test_https_url_with_http_proxy_refusal() -> Result<(), anyhow::Error> { 128 | let proxy_port = tools::start_refusing_proxy_server(false).await?; 129 | let proxy_url = Url::parse(&format!("http://localhost:{proxy_port}")).unwrap(); 130 | 131 | let settings = 
attohttpc::ProxySettingsBuilder::new() 132 | .http_proxy(proxy_url.clone()) 133 | .https_proxy(proxy_url) 134 | .build(); 135 | 136 | let mut sess = attohttpc::Session::new(); 137 | sess.proxy_settings(settings); 138 | 139 | let res = sess.get("https://localhost").danger_accept_invalid_certs(true).send(); 140 | 141 | let err = res.err().unwrap(); 142 | match err.kind() { 143 | attohttpc::ErrorKind::ConnectError { status_code, body } => { 144 | assert_eq!(status_code.as_u16(), 400); 145 | assert_eq!(body, b"bad request"); 146 | } 147 | _ => panic!("wrong error"), 148 | } 149 | 150 | Ok(()) 151 | } 152 | 153 | #[cfg(any(feature = "tls-native", feature = "__rustls"))] 154 | #[tokio::test(flavor = "multi_thread")] 155 | async fn test_http_url_with_https_proxy_refusal() -> Result<(), anyhow::Error> { 156 | let proxy_port = tools::start_refusing_proxy_server(true).await?; 157 | let proxy_url = Url::parse(&format!("https://localhost:{proxy_port}")).unwrap(); 158 | 159 | let settings = attohttpc::ProxySettingsBuilder::new() 160 | .http_proxy(proxy_url.clone()) 161 | .https_proxy(proxy_url) 162 | .build(); 163 | 164 | let mut sess = attohttpc::Session::new(); 165 | sess.proxy_settings(settings); 166 | 167 | let resp = sess 168 | .get("http://localhost") 169 | .danger_accept_invalid_certs(true) 170 | .send() 171 | .unwrap(); 172 | 173 | assert_eq!(resp.status().as_u16(), 400); 174 | assert_eq!(resp.text().unwrap(), "bad request"); 175 | 176 | Ok(()) 177 | } 178 | 179 | #[cfg(any(feature = "tls-native", feature = "__rustls"))] 180 | #[tokio::test(flavor = "multi_thread")] 181 | async fn test_https_url_with_https_proxy_refusal() -> Result<(), anyhow::Error> { 182 | let proxy_port = tools::start_refusing_proxy_server(true).await?; 183 | let proxy_url = Url::parse(&format!("https://localhost:{proxy_port}")).unwrap(); 184 | 185 | let settings = attohttpc::ProxySettingsBuilder::new() 186 | .http_proxy(proxy_url.clone()) 187 | .https_proxy(proxy_url) 188 | .build(); 189 | 190 | let 
mut sess = attohttpc::Session::new(); 191 | sess.proxy_settings(settings); 192 | 193 | let res = sess.get("https://localhost").danger_accept_invalid_certs(true).send(); 194 | 195 | let err = res.err().unwrap(); 196 | match err.kind() { 197 | attohttpc::ErrorKind::ConnectError { status_code, body } => { 198 | assert_eq!(status_code.as_u16(), 400); 199 | assert_eq!(body, b"bad request"); 200 | } 201 | _ => panic!("wrong error: {}", err), 202 | } 203 | 204 | Ok(()) 205 | } 206 | -------------------------------------------------------------------------------- /tests/test_redirection.rs: -------------------------------------------------------------------------------- 1 | use std::net::SocketAddr; 2 | 3 | use attohttpc::ErrorKind; 4 | use axum::body::Body; 5 | use axum::response::Response; 6 | use axum::routing::get; 7 | use axum::Router; 8 | use http::StatusCode; 9 | 10 | async fn make_server() -> Result { 11 | let addr = SocketAddr::from(([127, 0, 0, 1], 0)); 12 | let incoming = tokio::net::TcpListener::bind(&addr).await?; 13 | let local_addr = incoming.local_addr()?; 14 | 15 | async fn x301() -> Response { 16 | Response::builder() 17 | .status(StatusCode::MOVED_PERMANENTLY) 18 | .header("Location", "/301") 19 | .body(Body::from("")) 20 | .unwrap() 21 | } 22 | 23 | async fn x304() -> Response { 24 | Response::builder() 25 | .status(StatusCode::NOT_MODIFIED) 26 | .body(Body::from("")) 27 | .unwrap() 28 | } 29 | 30 | let app = Router::new().route("/301", get(x301)).route("/304", get(x304)); 31 | 32 | tokio::spawn(async move { 33 | axum::serve(incoming, app).await.unwrap(); 34 | }); 35 | 36 | Ok(local_addr.port()) 37 | } 38 | 39 | #[tokio::test(flavor = "multi_thread")] 40 | async fn test_redirection_default() -> Result<(), anyhow::Error> { 41 | let port = make_server().await?; 42 | 43 | match attohttpc::get(format!("http://localhost:{port}/301")).send() { 44 | Err(err) => match err.kind() { 45 | ErrorKind::TooManyRedirections => (), 46 | _ => panic!(), 47 | }, 48 | _ => 
panic!(), 49 | } 50 | 51 | Ok(()) 52 | } 53 | 54 | #[tokio::test(flavor = "multi_thread")] 55 | async fn test_redirection_0() -> Result<(), anyhow::Error> { 56 | let port = make_server().await?; 57 | 58 | match attohttpc::get(format!("http://localhost:{port}/301")) 59 | .max_redirections(0) 60 | .send() 61 | { 62 | Err(err) => match err.kind() { 63 | ErrorKind::TooManyRedirections => (), 64 | _ => panic!(), 65 | }, 66 | _ => panic!(), 67 | } 68 | 69 | Ok(()) 70 | } 71 | 72 | #[tokio::test(flavor = "multi_thread")] 73 | async fn test_redirection_disallowed() -> Result<(), anyhow::Error> { 74 | let port = make_server().await?; 75 | 76 | let resp = attohttpc::get(format!("http://localhost:{port}/301")) 77 | .follow_redirects(false) 78 | .send() 79 | .unwrap(); 80 | 81 | assert!(resp.status().is_redirection()); 82 | 83 | Ok(()) 84 | } 85 | 86 | #[tokio::test(flavor = "multi_thread")] 87 | async fn test_redirection_not_redirect() -> Result<(), anyhow::Error> { 88 | let port = make_server().await?; 89 | 90 | match attohttpc::get(format!("http://localhost:{port}/304")).send() { 91 | Ok(_) => (), 92 | _ => panic!(), 93 | } 94 | 95 | Ok(()) 96 | } 97 | -------------------------------------------------------------------------------- /tests/test_timeout.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::net::TcpListener; 3 | use std::thread; 4 | use std::time::Duration; 5 | 6 | #[test] 7 | fn request_fails_due_to_read_timeout() { 8 | let listener = TcpListener::bind("localhost:0").unwrap(); 9 | let port = listener.local_addr().unwrap().port(); 10 | let thread = thread::spawn(move || { 11 | let _stream = listener.accept().unwrap(); 12 | thread::sleep(Duration::from_millis(500)); 13 | }); 14 | 15 | let result = attohttpc::get(format!("http://localhost:{port}")) 16 | .read_timeout(Duration::from_millis(100)) 17 | .send(); 18 | 19 | match result { 20 | Err(err) => match err.kind() { 21 | attohttpc::ErrorKind::Io(err) 
=> match err.kind() { 22 | io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => (), 23 | err => panic!("Unexpected I/O error: {:?}", err), 24 | }, 25 | err => panic!("Unexpected error: {:?}", err), 26 | }, 27 | Ok(resp) => panic!("Unexpected response: {:?}", resp), 28 | } 29 | 30 | thread.join().unwrap(); 31 | } 32 | 33 | #[test] 34 | fn request_fails_due_to_timeout() { 35 | let listener = TcpListener::bind("localhost:0").unwrap(); 36 | let port = listener.local_addr().unwrap().port(); 37 | let thread = thread::spawn(move || { 38 | let _stream = listener.accept().unwrap(); 39 | thread::sleep(Duration::from_millis(500)); 40 | }); 41 | 42 | let result = attohttpc::get(format!("http://localhost:{port}")) 43 | .timeout(Duration::from_millis(100)) 44 | .send(); 45 | 46 | match result { 47 | Err(err) => match err.kind() { 48 | attohttpc::ErrorKind::Io(err) => match err.kind() { 49 | io::ErrorKind::TimedOut => (), 50 | err => panic!("Unexpected I/O error: {:?}", err), 51 | }, 52 | err => panic!("Unexpected error: {:?}", err), 53 | }, 54 | Ok(resp) => panic!("Unexpected response: {:?}", resp), 55 | } 56 | 57 | thread.join().unwrap(); 58 | } 59 | -------------------------------------------------------------------------------- /tests/tools/cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIFCTCCAvGgAwIBAgIUX5ZepTII1Wps0xuXJDA2eM27L0YwDQYJKoZIhvcNAQEL 3 | BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIwMTAwNTA0MDUwOFoXDTIzMDcw 4 | MjA0MDUwOFowFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEF 5 | AAOCAg8AMIICCgKCAgEA2qT+tjMRjDqlAzNm1DFRwpQF7j02rORJjlZvM4bZdHVg 6 | cNV/pvo7fS6q11gaQj6nBSD4Wokczr1G2+3xBIkhvCeewWpujOiOYFx+1MPxYwrr 7 | gzXZmCHKrTlKCPRA+5QbL9WjN0O7Ccg2QrOX5VPR4LGrHR0IC7kcVUiXg0V9yZKb 8 | O1A/vt8V+CB26NXDZ2Up/McF7lU76aMsw/EX8vZsjJ4vvJKC3XR8kwT/iSbHbNto 9 | CM28RDx+N3JbZ5ivkkgt/sxTQsMK09qZR0YcARS+Ya7T4VeB7dAJGXJ5STbm4lJR 10 | ozTSifoC+T49Q7vhN9qGuuDjrnQNP+6zZEvd5pDmW1187qXz/k+a1zmxkcNWEh0U 11 
| ehH8TYm0U8NhSo4x7n8wSH1gKZ3HpGaAiv8ZoryVXGMUKxDWTTSY99p766nWeDFa 12 | LTCiMYRUogjTAo6GOvHVV6HcK2xAaEjmWpQEITWGsU4hAIfD/PnIc7CtRmxalgBj 13 | oWuluFu0ZuzGSrmOm/ug04iVMG1PV8Yd/L5Gt2WUgNexO+SzrJG4jULi1Elage2W 14 | RsYgBXT1Wh8DFC4BBEkNDRqUkhRQohZ1LFT9m28O1/KYv2WTnRhtTEXCPfmwRNsZ 15 | AWNtvcsgai92gaO/ZAvSW/dOk3cytcOyDJN6nxdr6imJJcOMtzfry5+Tz3iW2UcC 16 | AwEAAaNTMFEwHQYDVR0OBBYEFGvVMxb9rBwTpwXa2oxFtdnlqzgYMB8GA1UdIwQY 17 | MBaAFGvVMxb9rBwTpwXa2oxFtdnlqzgYMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZI 18 | hvcNAQELBQADggIBAKXZtGc/Sn29wFIpJjcz/EMZYSeiYD2MkKtxJI3AigbTNVA+ 19 | qRof+I6ygJsX8Fz4oBSy/SeiYCOBZG7SfXVNi//LZvfyFKJlCEeqazvwa0B4zQeI 20 | Iza2VohaVUo4w4Rr/lc3HI0gsIzD+BzOv4+1Jn1PyzIdL50SNGlADwsZYuA1Us2b 21 | PLksztRU+c6C/sWEZXTLdAia2xSg+qJGimJ0Edk8FZN2XV0d9KblQh5nyQ7Ax+Rj 22 | vRYNzSwV8sPkBFonMj2ZM5jepNo7wnQm8L1GKqfpViJ+8n9KHUwbTOGQSeX2UJyq 23 | db4ECLWnChDkiuFRFLPJuYMTK1dYcIsXFYITw7MmrrzCdCvFDzlRHXRb8Jjfyf3P 24 | 2unCR82l7NvyTnrZFjCXr4Rkmk/ERhQxswrPS3VCVNGT0va4HxfnqOWruDz874C4 25 | mgBggBDO38MnhIfaE4xwHYcxyHvKb0DAiafc44trTM5tVqZhcFQ+fs8JIgDrJ0CY 26 | qZyrBJQ3g0Wti8wGOQ9E5SFtyZDFHTOgO2qjs8ljPI6+35aRuMK0LmnYKd11fXQM 27 | hYvwTYyw/gv5aoqTdUrVbQYe2qpNbnQaShgXqbmsIYVtSsrE1U7Y4TuEi0xmRE4U 28 | 4GN6fTyJNVXVpnqp/ef+V7go9DpRIEpmMGhzYhh3QhtpTO9hCyQ6d0qUUctj 29 | -----END CERTIFICATE----- 30 | -------------------------------------------------------------------------------- /tests/tools/generate-certs.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # https://stackoverflow.com/a/10176685 4 | openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 1000 -nodes -subj '/CN=localhost' 5 | -------------------------------------------------------------------------------- /tests/tools/key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIJRAIBADANBgkqhkiG9w0BAQEFAASCCS4wggkqAgEAAoICAQDapP62MxGMOqUD 3 | 
M2bUMVHClAXuPTas5EmOVm8zhtl0dWBw1X+m+jt9LqrXWBpCPqcFIPhaiRzOvUbb 4 | 7fEEiSG8J57Bam6M6I5gXH7Uw/FjCuuDNdmYIcqtOUoI9ED7lBsv1aM3Q7sJyDZC 5 | s5flU9HgsasdHQgLuRxVSJeDRX3Jkps7UD++3xX4IHbo1cNnZSn8xwXuVTvpoyzD 6 | 8Rfy9myMni+8koLddHyTBP+JJsds22gIzbxEPH43cltnmK+SSC3+zFNCwwrT2plH 7 | RhwBFL5hrtPhV4Ht0AkZcnlJNubiUlGjNNKJ+gL5Pj1Du+E32oa64OOudA0/7rNk 8 | S93mkOZbXXzupfP+T5rXObGRw1YSHRR6EfxNibRTw2FKjjHufzBIfWApncekZoCK 9 | /xmivJVcYxQrENZNNJj32nvrqdZ4MVotMKIxhFSiCNMCjoY68dVXodwrbEBoSOZa 10 | lAQhNYaxTiEAh8P8+chzsK1GbFqWAGOha6W4W7Rm7MZKuY6b+6DTiJUwbU9Xxh38 11 | vka3ZZSA17E75LOskbiNQuLUSVqB7ZZGxiAFdPVaHwMULgEESQ0NGpSSFFCiFnUs 12 | VP2bbw7X8pi/ZZOdGG1MRcI9+bBE2xkBY229yyBqL3aBo79kC9Jb906TdzK1w7IM 13 | k3qfF2vqKYklw4y3N+vLn5PPeJbZRwIDAQABAoICABmyBlU3Tzqhsb7cjZsmaKtr 14 | Zf8HpqNO+O4Hbm4pzOiI2tqn3KatBczCbUV9jyh14H8Tztwk/QO27trt7cNQsmxA 15 | Y8HrVi6tLf5ZIHa0yGuRXvg+neHpJr6Q2wiOXAy07TUD8HiQNy0BII9aN+G1pha4 16 | CpoTTVzDVncXXucIha3o0uF3nuY5pYVRm2flp19BXlvcr+/AiJf+m0Yz0VK2by9f 17 | b9DDjymjhPhqP/XIXuuGJYXTV+rVCShPAjr8kLWqUe196dMQxPb828KqlKPz4iPs 18 | ve3fsN9X8ZnSADsAfsXeUsXRcWTRwwOzstOBAwjjX64au+dGsD/xN5fz6pIsOUYP 19 | 5kYz+j8js3t9MsVo5bxJKy97F9tQTstWIjzhuwHXP3H0REmK+PRBTlaDYRjb5GGZ 20 | WsD++A6qht+gAuJIN7Ju0OKhvekimCAsLs1qBz+fMzU6UvumWBppp85QBM/VAc7C 21 | VQJq9IIuljHTkk/xODBU+bVgkG1dUviqoZsk8Xrp7+lOok3rq+gZvKyviYxJOhKE 22 | aftuVWt4UBo83AkHNqFf0sKw46nQVHjWVm7P9W5qLXeGihANAnsRe+AT0295B8L6 23 | BUkUW8Ls1xw1NBgpuYawoM1N+KrZzUTDcxEu/2WVLx0aVjAEgVgR28JNLGydpCDK 24 | wrQb2inqgkl+YJrOQE4RAoIBAQDz2T5e6adtMWguyQ2iO2zbEUKEBPn3CE9I/d8T 25 | DHYrMz/cvlNNARbrvBwr2fBXi5d2GWoxzl2Oo/+r+rgaCbu5QgvTAzRjolHbUa7R 26 | bksEBO0xOeKZh5pd3Zu3D8PqwOd/v408xeEOMOPJuM6ksMmg+qXi08IDmNrhJQOu 27 | YlfawiL98o1DMIXhVqqefQx+X7k3JPMEAF3pMqzgA62teQKRdrwbnREIpU4CU/9H 28 | UWHpZa1JJObrJUwfhucrWz8v3v1py+7rfjwJ1MjYThlDSzz8LO3dTZ7ZZvuCgV5U 29 | 9oWL8uCG7dag9CNYKxOKYo6vYnEhTUVJ5OWImXSe/RXJqabNAoIBAQDlijmK5c6L 30 | G53zxBEi7WcBjFX9zRm7rydY5gMMRIdX3iYWBWzGnxdou4f2bX22bipPjVJlnEG6 31 | 
HVDD3CGTlu9weN8icG0fl0xg2+GSQsJxvevRe5id8q4VgVWjFTVpkCZWIgHRtYuy 32 | aPLTmxLRblEnTYQiUf91ramB8tnwGKkPBt3LMlcu2I/lBMwZ3RWjY3Rh+izQc7Iu 33 | nqdiVK8+GB8WUZTINvL1+bjsYUQnF+xO5925UasPBsVyR37wvyr48m3rYCfWqC62 34 | 6ZReWoEv22lA28q07PkrjgEqVl8iW/Y2U71Vz5u1bScWYHpNAjNN4xoCEYCqG4/Z 35 | 3BQa6PvMeLhjAoIBAQDdMk6ymC8pDO1dq0FzzYYUPlbhHQaGgQIyyx/kzqQO0gF3 36 | QyqbKobZNh0z5hFvyI+PMMS1A2a4sqmFHvyLmhKwnWnOcxS94ItycGktT2g68CEd 37 | S03zSR/NT+4lo+Zrd6tHpcH/w9xAHDc1UDAbEscEAPojhUV0L9805netY0YgiYRm 38 | afScY1K+svg4pmSD6l8/14OeOJr3o+FPn+aW+rNKsrM2NUczWYgm8n0Z+4feyT3k 39 | oNXbrkycOKpQGQh7a0LjCnvjIPJjVKMQG+Asu/5JHpRy8CtNOG6j1qh0V9+SsLHd 40 | k7fi7zN0kja6jGa0T48hz08wFbPlMpaodhmiggNtAoIBAQDFWG6ajXNjQ+4r+yDk 41 | j4kk2pu/5ks+gptoofBy5qQaFmNWQXnA1lzQ7ZI6eu4/Hz3/QThbJdmYgsXCborr 42 | kjPv1eu5d0FKjNDc4Y5xKjZ0hd2uL/4JpfCnipeu63gNdyKOoRRBUT12uSY0abnn 43 | F1psWQfGXgstCI6Yq8tV0k3fHn8nyPMN0qD8PIChp+OHouUXq3hGC4Jg9IRiVP6J 44 | e0GW9bLd4+hFNFsZ15I5ria7vDtzrdRMyfyNgQyRPTpb6Eo6uGPn+JaZKf8mk9dH 45 | 5vL6ET2ZCbg8Cw3TcYi/SKNJF/vLP//WHq8hXPcpKEAlsZ/VsWiWV7X4j281p6N8 46 | qzmTAoIBAQCCp7uwFx1iPH4YRcY3UeWupbQWvbdWDjOzvteUtKiZnYTiSO1vfY+a 47 | 50OSOTB9Ii/5xtmuRlhdk2eAomJA5RjVGxlSBGyoAK1AtU1jVkQS14e62JD5c7W6 48 | 1N2yR0VslZWv98UpnKZqTT3iNMsyX4dEBx5817uDS1RJl9VykYw2bLpwrQP/8Tw4 49 | VnGkoGYdSPWkq4jyiDjmDjsDLcGfdvwjnufqJZQycWvy9eUsVZKCEunCk9dlg89/ 50 | kfvvTZtCQfuW0onL3909DTW3RJa8cKhK0ipSighiH99RlCs2WqZfY5ej0aclotcw 51 | yhKlH++rN2YnKCSg9iSJ2Og69kNvqzIB 52 | -----END PRIVATE KEY----- 53 | -------------------------------------------------------------------------------- /tests/tools/mod.rs: -------------------------------------------------------------------------------- 1 | mod proxy; 2 | mod servers; 3 | 4 | pub use proxy::*; 5 | pub use servers::*; 6 | -------------------------------------------------------------------------------- /tests/tools/proxy.rs: -------------------------------------------------------------------------------- 1 | // This code has been taken from the hyper project and slightly 
modified: https://github.com/hyperium/hyper/blob/master/examples/http_proxy.rs 2 | // It's needed to create a proxy server for testing. 3 | 4 | use axum_server::tls_rustls::RustlsConfig; 5 | use bytes::Bytes; 6 | use http::{Method, Request, Response}; 7 | use http_body_util::combinators::BoxBody; 8 | use http_body_util::{BodyExt, Empty, Full}; 9 | use hyper::client::conn::http1::Builder as ClientBuilder; 10 | use hyper::service::service_fn; 11 | use hyper::upgrade::Upgraded; 12 | use hyper_util::rt::{TokioExecutor, TokioIo}; 13 | use hyper_util::server::conn::auto::Builder; 14 | use std::net::SocketAddr; 15 | use tokio::net::{TcpListener, TcpStream}; 16 | use tokio_rustls::TlsAcceptor; 17 | 18 | pub async fn start_proxy_server(tls: bool) -> anyhow::Result { 19 | create_proxy(tls, false).await 20 | } 21 | 22 | pub async fn start_refusing_proxy_server(tls: bool) -> anyhow::Result { 23 | create_proxy(tls, true).await 24 | } 25 | 26 | // Code below is derived from these examples: 27 | // Hyper proxy: https://github.com/hyperium/hyper/blob/master/examples/http_proxy.rs 28 | // Hyper TLS server: https://github.com/rustls/hyper-rustls/blob/main/examples/server.rs 29 | 30 | async fn proxy_allow( 31 | req: Request, 32 | ) -> Result>, hyper::Error> { 33 | if Method::CONNECT == req.method() { 34 | if let Some(addr) = host_addr(req.uri()) { 35 | tokio::task::spawn(async move { 36 | match hyper::upgrade::on(req).await { 37 | Ok(upgraded) => { 38 | if let Err(e) = tunnel(upgraded, addr).await { 39 | eprintln!("server io error: {}", e); 40 | }; 41 | } 42 | Err(e) => eprintln!("upgrade error: {}", e), 43 | } 44 | }); 45 | 46 | Ok(Response::new(empty())) 47 | } else { 48 | eprintln!("CONNECT host is not socket addr: {:?}", req.uri()); 49 | let mut resp = Response::new(full("CONNECT must be to a socket address")); 50 | *resp.status_mut() = http::StatusCode::BAD_REQUEST; 51 | Ok(resp) 52 | } 53 | } else { 54 | let host = req.uri().host().expect("uri has no host"); 55 | let port = 
req.uri().port_u16().unwrap_or(80); 56 | 57 | let stream = TcpStream::connect((host, port)).await.unwrap(); 58 | let io = TokioIo::new(stream); 59 | 60 | let (mut sender, conn) = ClientBuilder::new() 61 | .preserve_header_case(true) 62 | .title_case_headers(true) 63 | .handshake(io) 64 | .await?; 65 | tokio::task::spawn(async move { 66 | if let Err(err) = conn.await { 67 | println!("Connection failed: {:?}", err); 68 | } 69 | }); 70 | 71 | let resp = sender.send_request(req).await?; 72 | Ok(resp.map(|b| b.boxed())) 73 | } 74 | } 75 | 76 | async fn proxy_deny( 77 | _req: Request, 78 | ) -> Result>, hyper::Error> { 79 | let mut resp = Response::new(full("bad request")); 80 | *resp.status_mut() = http::StatusCode::BAD_REQUEST; 81 | Ok(resp) 82 | } 83 | 84 | fn host_addr(uri: &http::Uri) -> Option { 85 | uri.authority().map(|auth| auth.to_string()) 86 | } 87 | 88 | fn empty() -> BoxBody { 89 | Empty::::new().map_err(|never| match never {}).boxed() 90 | } 91 | 92 | fn full>(chunk: T) -> BoxBody { 93 | Full::new(chunk.into()).map_err(|never| match never {}).boxed() 94 | } 95 | 96 | async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> { 97 | let mut server = TcpStream::connect(addr).await?; 98 | let mut upgraded = TokioIo::new(upgraded); 99 | tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?; 100 | Ok(()) 101 | } 102 | 103 | async fn create_proxy(tls: bool, deny: bool) -> anyhow::Result { 104 | let addr = SocketAddr::from(([127, 0, 0, 1], 0)); 105 | let listener = TcpListener::bind(addr).await?; 106 | let port = listener.local_addr().unwrap().port(); 107 | 108 | if tls { 109 | let config = RustlsConfig::from_pem(include_bytes!("cert.pem").to_vec(), include_bytes!("key.pem").to_vec()) 110 | .await 111 | .unwrap(); 112 | let tls_acceptor = TlsAcceptor::from(config.get_inner()); 113 | 114 | tokio::spawn(async move { 115 | loop { 116 | let (tcp_stream, _remote_addr) = listener.accept().await.unwrap(); 117 | let tls_acceptor = 
tls_acceptor.clone(); 118 | tokio::spawn(async move { 119 | let tls_stream = match tls_acceptor.accept(tcp_stream).await { 120 | Ok(tls_stream) => tls_stream, 121 | Err(err) => { 122 | eprintln!("failed to perform tls handshake: {err:#}"); 123 | return; 124 | } 125 | }; 126 | if let Err(err) = Builder::new(TokioExecutor::new()) 127 | .serve_connection_with_upgrades( 128 | TokioIo::new(tls_stream), 129 | service_fn(move |req| async move { 130 | match deny { 131 | true => proxy_deny(req).await, 132 | false => proxy_allow(req).await, 133 | } 134 | }), 135 | ) 136 | .await 137 | { 138 | eprintln!("failed to serve connection: {err:#}"); 139 | } 140 | }); 141 | } 142 | }); 143 | } else { 144 | tokio::spawn(async move { 145 | loop { 146 | let (tcp_stream, _remote_addr) = listener.accept().await.unwrap(); 147 | tokio::spawn(async move { 148 | if let Err(err) = Builder::new(TokioExecutor::new()) 149 | .serve_connection_with_upgrades( 150 | TokioIo::new(tcp_stream), 151 | service_fn(move |req| async move { 152 | match deny { 153 | true => proxy_deny(req).await, 154 | false => proxy_allow(req).await, 155 | } 156 | }), 157 | ) 158 | .await 159 | { 160 | eprintln!("failed to serve connection: {err:#}"); 161 | } 162 | }); 163 | } 164 | }); 165 | } 166 | 167 | Ok(port) 168 | } 169 | -------------------------------------------------------------------------------- /tests/tools/servers.rs: -------------------------------------------------------------------------------- 1 | use std::net::SocketAddr; 2 | 3 | use axum::body::Body; 4 | use axum::http::StatusCode; 5 | use axum::response::Response; 6 | use axum::routing::get; 7 | use axum::Router; 8 | use axum_server::tls_rustls::{from_tcp_rustls, RustlsConfig}; 9 | 10 | pub async fn start_hello_world_server(tls: bool) -> anyhow::Result { 11 | let addr = SocketAddr::from(([127, 0, 0, 1], 0)); 12 | let incoming = tokio::net::TcpListener::bind(&addr).await?; 13 | let local_addr = incoming.local_addr()?; 14 | 15 | async fn hello_world() -> 
Response { 16 | Response::builder() 17 | .status(StatusCode::OK) 18 | .body(Body::from("hello")) 19 | .unwrap() 20 | } 21 | 22 | let app = Router::new().route("/", get(hello_world)); 23 | 24 | if tls { 25 | let config = RustlsConfig::from_pem(include_bytes!("cert.pem").to_vec(), include_bytes!("key.pem").to_vec()) 26 | .await 27 | .unwrap(); 28 | 29 | tokio::spawn(async move { 30 | from_tcp_rustls(incoming.into_std().unwrap(), config) 31 | .serve(app.into_make_service()) 32 | .await 33 | .unwrap(); 34 | }); 35 | } else { 36 | tokio::spawn(async move { 37 | axum::serve(incoming, app).await.unwrap(); 38 | }); 39 | } 40 | 41 | Ok(local_addr.port()) 42 | } 43 | -------------------------------------------------------------------------------- /tools/clippy.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -Eeuxo pipefail 3 | 4 | cargo clippy --all-features --all-targets -- --deny warnings 5 | -------------------------------------------------------------------------------- /tools/tests.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -Eeuxo pipefail 3 | 4 | unset http_proxy https_proxy no_proxy 5 | unset HTTP_PROXY HTTPS_PROXY NO_PROXY 6 | 7 | if [[ "${CI:-}" == "true" ]] ; then 8 | mkdir -p .cargo 9 | echo "[term]" >> .cargo/config.toml 10 | echo "color = 'always'" >> .cargo/config.toml 11 | fi 12 | 13 | function testwrap { 14 | if which cargo-nextest ; then 15 | cargo nextest run "$@" 16 | else 17 | cargo test "$@" 18 | fi 19 | } 20 | 21 | testwrap 22 | testwrap --all-features 23 | testwrap --no-default-features 24 | testwrap --no-default-features --features basic-auth 25 | testwrap --no-default-features --features charsets 26 | testwrap --no-default-features --features compress 27 | testwrap --no-default-features --features compress-zlib 28 | testwrap --no-default-features --features compress-zlib-ng 29 | testwrap --no-default-features --features 
form 30 | testwrap --no-default-features --features multipart-form 31 | testwrap --no-default-features --features json 32 | testwrap --no-default-features --features tls-native 33 | testwrap --no-default-features --features tls-native,tls-native-vendored 34 | testwrap --no-default-features --features tls-rustls-webpki-roots 35 | testwrap --no-default-features --features tls-rustls-native-roots 36 | --------------------------------------------------------------------------------