├── mendes ├── templates │ └── hello.html ├── src │ ├── utils.rs │ ├── lib.rs │ ├── key.rs │ ├── hyper.rs │ ├── cookies.rs │ ├── body.rs │ ├── forms.rs │ ├── multipart.rs │ └── application.rs ├── tests │ ├── forms.rs │ ├── readme.rs │ ├── hyper.rs │ ├── cookies.rs │ ├── body.rs │ └── basic.rs └── Cargo.toml ├── .github ├── FUNDING.yml ├── dependabot.yml └── workflows │ └── rust.yml ├── .gitignore ├── Cargo.toml ├── codecov.yml ├── deny.toml ├── mendes-macros ├── Cargo.toml └── src │ ├── lib.rs │ ├── cookies.rs │ ├── route.rs │ └── forms.rs ├── LICENSE-MIT ├── README.md └── LICENSE-APACHE /mendes/templates/hello.html: -------------------------------------------------------------------------------- 1 | Hello, {{ name }}! 2 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [djc] 2 | patreon: dochtman 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/*.rs.bk 3 | Cargo.lock 4 | .vscode 5 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = ["mendes", "mendes-macros"] 3 | resolver = "2" 4 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | ignore: 2 | - "mendes-macros" # coverage of procedural macros has not been set up 3 | -------------------------------------------------------------------------------- /deny.toml: -------------------------------------------------------------------------------- 1 | [licenses] 2 | allow = [ 3 | "Apache-2.0", 4 | "BSD-3-Clause", 5 | "ISC", 6 | "MIT", 7 | "OpenSSL", 8 | 
"Unicode-3.0", 9 | ] 10 | 11 | [[licenses.clarify]] 12 | name = "ring" 13 | expression = "ISC AND MIT AND OpenSSL" 14 | license-files = [{ path = "LICENSE", hash = 0xbd0eed23 }] 15 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: cargo 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | time: "04:00" 8 | open-pull-requests-limit: 10 9 | ignore: 10 | - dependency-name: tokio 11 | versions: 12 | - 1.1.0 13 | - 1.1.1 14 | - package-ecosystem: github-actions 15 | directory: "/" 16 | schedule: 17 | interval: weekly 18 | -------------------------------------------------------------------------------- /mendes-macros/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mendes-macros" 3 | version = "0.4.0" 4 | edition = "2021" 5 | rust-version = "1.63" 6 | description = "Macros for mendes web toolkit" 7 | documentation = "https://docs.rs/mendes-macros" 8 | repository = "https://github.com/djc/mendes" 9 | license = "MIT OR Apache-2.0" 10 | workspace = ".." 
11 | 12 | [lib] 13 | proc-macro = true 14 | 15 | [dependencies] 16 | quote = "1.0.2" 17 | syn = { version = "2", features = ["full"] } 18 | proc-macro2 = "1.0.8" 19 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Dirkjan Ochtman 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
8 | -------------------------------------------------------------------------------- /mendes/src/utils.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "static")] 2 | mod file_mod { 3 | use crate::application::Error; 4 | use http::header::{CONTENT_LENGTH, CONTENT_TYPE}; 5 | use http::StatusCode; 6 | use std::path::PathBuf; 7 | use tokio::fs; 8 | 9 | pub async fn file(mut path: PathBuf) -> Result, Error> 10 | where 11 | B: From>, 12 | { 13 | let mut metadata = fs::metadata(&path).await.map_err(|_| Error::FileNotFound)?; 14 | if metadata.is_dir() { 15 | path = path.join("index.html"); 16 | metadata = fs::metadata(&path).await.map_err(|_| Error::FileNotFound)?; 17 | } 18 | 19 | let mut builder = http::Response::builder() 20 | .status(StatusCode::OK) 21 | .header(CONTENT_LENGTH, metadata.len()); 22 | 23 | if let Some(mime) = mime_guess::from_path(&path).first() { 24 | builder = builder.header(CONTENT_TYPE, mime.to_string()); 25 | } 26 | 27 | let bytes = fs::read(path).await.map_err(|_| Error::FileNotFound)?; 28 | Ok(builder.body(B::from(bytes)).unwrap()) 29 | } 30 | } 31 | 32 | #[cfg(feature = "static")] 33 | #[cfg_attr(docsrs, doc(cfg(feature = "static")))] 34 | pub use file_mod::file; 35 | -------------------------------------------------------------------------------- /mendes/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(docsrs, feature(doc_cfg))] 2 | 3 | #[cfg(feature = "http")] 4 | #[cfg_attr(docsrs, doc(cfg(feature = "http")))] 5 | /// Re-export of the http crate 6 | pub use http; 7 | 8 | #[cfg(feature = "application")] 9 | #[cfg_attr(docsrs, doc(cfg(feature = "application")))] 10 | /// Core of the Mendes web application toolkit 11 | pub mod application; 12 | #[cfg(feature = "application")] 13 | pub use application::{handler, route, scope, Application, Context, Error, FromContext}; 14 | 15 | #[cfg(feature = "application")] 16 | #[cfg_attr(docsrs, 
doc(cfg(feature = "application")))] 17 | pub mod body; 18 | #[cfg(feature = "application")] 19 | pub use body::Body; 20 | 21 | #[cfg(feature = "cookies")] 22 | #[cfg_attr(docsrs, doc(cfg(feature = "cookies")))] 23 | /// Cookie support 24 | pub mod cookies; 25 | 26 | #[cfg(feature = "key")] 27 | #[cfg_attr(docsrs, doc(cfg(feature = "key")))] 28 | /// AEAD encryption/decryption support 29 | pub mod key; 30 | 31 | #[cfg(feature = "forms")] 32 | #[cfg_attr(docsrs, doc(cfg(feature = "forms")))] 33 | /// Form generation and data validation 34 | pub mod forms; 35 | 36 | /// Some helpers 37 | pub mod utils; 38 | 39 | #[cfg(feature = "hyper")] 40 | #[cfg_attr(docsrs, doc(cfg(feature = "hyper")))] 41 | /// Optional features that require hyper 42 | pub mod hyper; 43 | 44 | #[cfg(feature = "uploads")] 45 | mod multipart; 46 | 47 | /// Some content type definitions 48 | pub mod types { 49 | pub const HTML: &str = "text/html"; 50 | pub const JSON: &str = "application/json"; 51 | } 52 | -------------------------------------------------------------------------------- /mendes/tests/forms.rs: -------------------------------------------------------------------------------- 1 | #![cfg(feature = "forms")] 2 | 3 | use std::borrow::Cow; 4 | 5 | use mendes::forms::{form, ToField, ToForm}; 6 | use serde::{Deserialize, Serialize}; 7 | 8 | #[test] 9 | fn test_generate() { 10 | let form = SomeForm::to_form(); 11 | let form = form.set("name", "hi").unwrap(); 12 | let html = form.to_string(); 13 | assert!(!html.contains("skipped")); 14 | } 15 | 16 | #[test] 17 | fn test_roundtrip() { 18 | let obj = SomeForm { 19 | skipped: 0, 20 | name: "name".into(), 21 | amount: 1, 22 | rate: 2.0, 23 | byte: 3, 24 | test: true, 25 | options: Options::Straight, 26 | #[cfg(feature = "chrono")] 27 | date: chrono::Utc::now().date_naive(), 28 | }; 29 | let s = serde_urlencoded::to_string(&obj).unwrap(); 30 | let decoded = serde_urlencoded::from_bytes(s.as_bytes()).unwrap(); 31 | assert_eq!(obj, decoded); 32 | } 
33 | 34 | #[allow(dead_code)] 35 | #[form(action = "/assets/new", submit = "Create")] 36 | #[derive(Debug, Deserialize, Serialize, PartialEq)] 37 | struct SomeForm<'a> { 38 | #[form(skip)] 39 | skipped: u8, 40 | name: Cow<'a, str>, 41 | amount: u32, 42 | rate: f32, 43 | byte: u8, 44 | #[form(item = "Group")] 45 | test: bool, 46 | #[form(item = "Group")] 47 | options: Options, 48 | #[cfg(feature = "chrono")] 49 | date: chrono::NaiveDate, 50 | } 51 | 52 | #[derive(Debug, Deserialize, Serialize, ToField, PartialEq)] 53 | enum Options { 54 | Straight, 55 | #[option(label = "Relabeled")] 56 | Labeled, 57 | } 58 | -------------------------------------------------------------------------------- /mendes/tests/readme.rs: -------------------------------------------------------------------------------- 1 | #![cfg(all(feature = "application", feature = "hyper", feature = "body-util"))] 2 | 3 | use async_trait::async_trait; 4 | use bytes::Bytes; 5 | use http_body_util::Full; 6 | use mendes::application::IntoResponse; 7 | use mendes::http::request::Parts; 8 | use mendes::http::{Response, StatusCode}; 9 | use mendes::hyper::body::Incoming; 10 | use mendes::{handler, route, Application, Context}; 11 | 12 | #[handler(GET)] 13 | async fn hello(_: &App) -> Result>, Error> { 14 | Ok(Response::builder() 15 | .status(StatusCode::OK) 16 | .body("Hello, world".into()) 17 | .unwrap()) 18 | } 19 | 20 | struct App {} 21 | 22 | #[async_trait] 23 | impl Application for App { 24 | type RequestBody = Incoming; 25 | type ResponseBody = Full; 26 | type Error = Error; 27 | 28 | async fn handle(mut cx: Context) -> Response> { 29 | route!(match cx.path() { 30 | _ => hello, 31 | }) 32 | } 33 | } 34 | 35 | #[derive(Debug)] 36 | enum Error { 37 | Mendes(mendes::Error), 38 | } 39 | 40 | impl From for Error { 41 | fn from(e: mendes::Error) -> Self { 42 | Error::Mendes(e) 43 | } 44 | } 45 | 46 | impl From<&Error> for StatusCode { 47 | fn from(e: &Error) -> StatusCode { 48 | let Error::Mendes(e) = e; 49 | 
StatusCode::from(e) 50 | } 51 | } 52 | 53 | impl IntoResponse for Error { 54 | fn into_response(self, _: &App, _: &Parts) -> Response> { 55 | let Error::Mendes(err) = self; 56 | Response::builder() 57 | .status(StatusCode::from(&err)) 58 | .body(Full::new(Bytes::from(err.to_string()))) 59 | .unwrap() 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: ['main'] 6 | pull_request: 7 | 8 | jobs: 9 | test: 10 | strategy: 11 | matrix: 12 | os: [ubuntu-latest, macos-latest, windows-latest] 13 | rust: [stable, beta, 1.83] 14 | exclude: 15 | - os: macos-latest 16 | rust: beta 17 | - os: windows-latest 18 | rust: beta 19 | - os: macos-latest 20 | rust: 1.83 21 | - os: windows-latest 22 | rust: 1.83 23 | 24 | runs-on: ${{ matrix.os }} 25 | 26 | steps: 27 | - uses: actions/checkout@v6 28 | - uses: dtolnay/rust-toolchain@master 29 | with: 30 | toolchain: ${{ matrix.rust }} 31 | - run: cargo check --all-targets --all-features 32 | - run: cargo test --all-features 33 | 34 | default: 35 | runs-on: ubuntu-latest 36 | steps: 37 | - uses: actions/checkout@v6 38 | - uses: dtolnay/rust-toolchain@stable 39 | - run: cargo check --all-targets 40 | - run: cargo test 41 | 42 | lint: 43 | runs-on: ubuntu-latest 44 | steps: 45 | - uses: actions/checkout@v6 46 | - uses: dtolnay/rust-toolchain@stable 47 | with: 48 | components: rustfmt, clippy 49 | - run: cargo fmt --all -- --check 50 | - run: cargo clippy --all-targets --all-features -- --deny warnings 51 | 52 | coverage: 53 | runs-on: ubuntu-latest 54 | environment: Coverage 55 | env: 56 | CARGO_TERM_COLOR: always 57 | steps: 58 | - uses: actions/checkout@v6 59 | - uses: dtolnay/rust-toolchain@stable 60 | - name: Install cargo-llvm-cov 61 | uses: taiki-e/install-action@cargo-llvm-cov 62 | - name: Generate code coverage 63 | run: cargo 
llvm-cov --all-features --workspace --lcov --output-path lcov.info 64 | - name: Upload coverage to Codecov 65 | uses: codecov/codecov-action@v5 66 | with: 67 | token: ${{ secrets.CODECOV_TOKEN }} 68 | files: lcov.info 69 | fail_ci_if_error: true 70 | 71 | audit: 72 | runs-on: ubuntu-latest 73 | steps: 74 | - uses: actions/checkout@v6 75 | - uses: EmbarkStudios/cargo-deny-action@v2 76 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mendes: web toolkit for impatient perfectionists 2 | 3 | [![Documentation](https://docs.rs/mendes/badge.svg)](https://docs.rs/mendes/) 4 | [![Crates.io](https://img.shields.io/crates/v/mendes.svg)](https://crates.io/crates/mendes) 5 | [![Build status](https://github.com/djc/mendes/workflows/CI/badge.svg)](https://github.com/djc/mendes/actions?query=workflow%3ACI) 6 | [![Coverage status](https://codecov.io/gh/djc/mendes/branch/main/graph/badge.svg)](https://codecov.io/gh/djc/mendes) 7 | [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE-MIT) 8 | [![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE-APACHE) 9 | 10 | Mendes is a Rust web toolkit for impatient perfectionists (apologies to Django). 11 | It aims to be: 12 | 13 | * Modular: less framework, more library; pick and choose components 14 | * Async: async/await from the start 15 | * Low boilerplate: easy to get started, but with limited "magic" 16 | * Type-safe: leverage the type system to make error handling low effort 17 | * Secure: provide security by default; no unsafe code in this project 18 | * Run on stable Rust (no promises on MSRV though) 19 | 20 | Mendes is currently in an extremely early phase and probably not ready for anything 21 | but experiments for those who are curious. Feedback is always welcome though! 
22 | 23 | ## Minimal example 24 | 25 | This should definitely become more minimal over time. 26 | 27 | ```rust 28 | use async_trait::async_trait; 29 | use hyper::Body; 30 | use mendes::application::IntoResponse; 31 | use mendes::http::request::Parts; 32 | use mendes::http::{Response, StatusCode}; 33 | use mendes::{handler, route, Application, Context}; 34 | 35 | #[handler(GET)] 36 | async fn hello(_: &App) -> Result, Error> { 37 | Ok(Response::builder() 38 | .status(StatusCode::OK) 39 | .body("Hello, world".into()) 40 | .unwrap()) 41 | } 42 | 43 | struct App {} 44 | 45 | #[async_trait] 46 | impl Application for App { 47 | type RequestBody = (); 48 | type ResponseBody = Body; 49 | type Error = Error; 50 | 51 | async fn handle(mut cx: Context) -> Response { 52 | route!(match cx.path() { 53 | _ => hello, 54 | }) 55 | } 56 | } 57 | 58 | #[derive(Debug)] 59 | enum Error { 60 | Mendes(mendes::Error), 61 | } 62 | 63 | impl From for Error { 64 | fn from(e: mendes::Error) -> Self { 65 | Error::Mendes(e) 66 | } 67 | } 68 | 69 | impl From<&Error> for StatusCode { 70 | fn from(e: &Error) -> StatusCode { 71 | let Error::Mendes(e) = e; 72 | StatusCode::from(e) 73 | } 74 | } 75 | 76 | impl IntoResponse for Error { 77 | fn into_response(self, _: &App, _: &Parts) -> Response { 78 | let Error::Mendes(err) = self; 79 | Response::builder() 80 | .status(StatusCode::from(&err)) 81 | .body(err.to_string().into()) 82 | .unwrap() 83 | } 84 | } 85 | ``` 86 | 87 | All feedback welcome. Feel free to file bugs, requests for documentation and 88 | any other feedback to the [issue tracker][issues]. 89 | 90 | Mendes was created and is maintained by Dirkjan Ochtman. 
91 | 92 | [issues]: https://github.com/djc/mendes/issues 93 | -------------------------------------------------------------------------------- /mendes/src/key.rs: -------------------------------------------------------------------------------- 1 | use std::convert::TryInto; 2 | 3 | use data_encoding::HEXLOWER; 4 | use ring::rand::SecureRandom; 5 | use ring::{aead, rand}; 6 | use thiserror::Error; 7 | 8 | #[cfg(feature = "application")] 9 | use crate::application::Application; 10 | 11 | /// Give mendes-based APIs access to an AEAD key for the `Application` 12 | /// 13 | /// AEAD (Authenticated Encryption with Associated Data) encrypts data and authenticates 14 | /// it such that other parties cannot read or manipulate the encrypted data. Currently 15 | /// mendes uses this only to encrypt and authenticate cookie data. 16 | #[cfg(feature = "application")] 17 | #[cfg_attr(docsrs, doc(cfg(feature = "application")))] 18 | pub trait AppWithAeadKey: Application { 19 | fn key(&self) -> &Key; 20 | } 21 | 22 | /// An encryption key to authenticate and encrypt/decrypt cookie values 23 | /// 24 | /// This currently uses the ChaCha20-Poly1305 algorithm as defined in RFC 7539. 25 | pub struct Key(aead::LessSafeKey); 26 | 27 | impl Key { 28 | /// Create a new `Key` from the given secret key 29 | pub fn new(secret: &[u8; 32]) -> Self { 30 | Self(aead::LessSafeKey::new( 31 | aead::UnboundKey::new(&aead::CHACHA20_POLY1305, secret).unwrap(), 32 | )) 33 | } 34 | 35 | /// Create key from slice of hexadecimal characters 36 | /// 37 | /// This will fail if the slice is not valid lowercase hexadecimal or does not decode to exactly 32 bytes. 
38 | pub fn from_hex_lower(s: &[u8]) -> Result { 39 | let bytes = HEXLOWER 40 | .decode(s) 41 | .map_err(|_| Error::InvalidKeyCharacters)?; 42 | Ok(Self::new( 43 | (&*bytes).try_into().map_err(|_| Error::InvalidKeyLength)?, 44 | )) 45 | } 46 | 47 | pub fn decrypt<'a>(&self, aad: &[u8], input: &'a mut [u8]) -> Result<&'a [u8], Error> { 48 | if input.len() <= NONCE_LEN { 49 | return Err(Error::Decryption); 50 | } 51 | 52 | let ad = aead::Aad::from(aad); 53 | let (sealed, nonce) = input.split_at_mut(input.len() - NONCE_LEN); 54 | aead::Nonce::try_assume_unique_for_key(nonce) 55 | .and_then(move |nonce| self.0.open_in_place(nonce, ad, sealed)) 56 | .map(|plain| &*plain) 57 | .map_err(|_| Error::Decryption) 58 | } 59 | 60 | pub fn encrypt(&self, aad: &[u8], buf: &mut Vec) -> Result<(), Error> { 61 | let mut nonce_buf = [0; NONCE_LEN]; 62 | rand::SystemRandom::new() 63 | .fill(&mut nonce_buf) 64 | .map_err(|_| Error::GetRandomFailed)?; 65 | let nonce = aead::Nonce::try_assume_unique_for_key(&nonce_buf).unwrap(); // valid nonce length 66 | 67 | let aad = aead::Aad::from(aad); 68 | self.0.seal_in_place_append_tag(nonce, aad, buf).unwrap(); // unique nonce 69 | buf.extend(nonce_buf); 70 | Ok(()) 71 | } 72 | } 73 | 74 | #[derive(Debug, Error)] 75 | pub enum Error { 76 | #[error("failed to decrypt")] 77 | Decryption, 78 | #[error("failed to acquire random bytes for nonce")] 79 | GetRandomFailed, 80 | #[error("invalid key characters")] 81 | InvalidKeyCharacters, 82 | #[error("invalid key length")] 83 | InvalidKeyLength, 84 | } 85 | 86 | pub(crate) const NONCE_LEN: usize = 12; 87 | #[cfg(all(feature = "cookies", feature = "application"))] 88 | pub(crate) const TAG_LEN: usize = 16; 89 | -------------------------------------------------------------------------------- /mendes/tests/hyper.rs: -------------------------------------------------------------------------------- 1 | #![cfg(feature = "hyper")] 2 | 3 | use std::fmt::{self, Display}; 4 | use std::io; 5 | use 
std::net::SocketAddr; 6 | use std::time::Duration; 7 | 8 | use async_trait::async_trait; 9 | use bytes::Bytes; 10 | use mendes::application::IntoResponse; 11 | use mendes::http::request::Parts; 12 | use mendes::http::{Response, StatusCode}; 13 | use mendes::hyper::body::Incoming; 14 | use mendes::hyper::{ClientAddr, Server}; 15 | use mendes::{handler, route, Application, Body, Context}; 16 | use tokio::task::JoinHandle; 17 | use tokio::time::sleep; 18 | 19 | struct ServerRunner { 20 | handle: JoinHandle>, 21 | } 22 | 23 | impl ServerRunner { 24 | async fn run(addr: SocketAddr) -> Self { 25 | let handle = tokio::spawn(Server::bind(addr, App::default()).await.unwrap().serve()); 26 | sleep(Duration::from_millis(10)).await; 27 | Self { handle } 28 | } 29 | 30 | fn stop(self) { 31 | self.handle.abort(); 32 | } 33 | } 34 | 35 | impl Drop for ServerRunner { 36 | fn drop(&mut self) { 37 | self.handle.abort(); 38 | } 39 | } 40 | 41 | #[tokio::test] 42 | async fn test_client_addr() { 43 | let addr = "127.0.0.1:12345".parse::().unwrap(); 44 | let runner = ServerRunner::run(addr).await; 45 | 46 | let rsp = reqwest::get(format!("http://{addr}/client-addr")) 47 | .await 48 | .unwrap(); 49 | assert_eq!(rsp.status(), StatusCode::OK); 50 | 51 | let body = rsp.text().await.unwrap(); 52 | assert_eq!(body, "client_addr: 127.0.0.1"); 53 | 54 | runner.stop(); 55 | } 56 | 57 | #[derive(Default)] 58 | struct App {} 59 | 60 | #[async_trait] 61 | impl Application for App { 62 | type RequestBody = Incoming; 63 | type ResponseBody = Body; 64 | type Error = Error; 65 | 66 | async fn handle(mut cx: Context) -> Response { 67 | route!(match cx.path() { 68 | Some("client-addr") => client_addr, 69 | }) 70 | } 71 | } 72 | 73 | #[handler(GET)] 74 | async fn client_addr(_: &App, client_addr: ClientAddr) -> Result, Error> { 75 | Ok(Response::builder() 76 | .status(StatusCode::OK) 77 | .body(Body::from(Bytes::from(format!( 78 | "client_addr: {}", 79 | client_addr.ip() 80 | )))) 81 | .unwrap()) 82 | } 83 
| 84 | #[derive(Debug)] 85 | enum Error { 86 | Mendes(mendes::Error), 87 | } 88 | 89 | impl std::error::Error for Error {} 90 | 91 | impl Display for Error { 92 | fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 93 | match self { 94 | Error::Mendes(err) => err.fmt(formatter), 95 | } 96 | } 97 | } 98 | 99 | impl From for Error { 100 | fn from(e: mendes::Error) -> Self { 101 | Error::Mendes(e) 102 | } 103 | } 104 | 105 | impl From<&Error> for StatusCode { 106 | fn from(e: &Error) -> StatusCode { 107 | let Error::Mendes(e) = e; 108 | StatusCode::from(e) 109 | } 110 | } 111 | 112 | impl IntoResponse for Error { 113 | fn into_response(self, _: &App, _: &Parts) -> Response { 114 | let Error::Mendes(err) = self; 115 | Response::builder() 116 | .status(StatusCode::from(&err)) 117 | .body(Body::from(Bytes::from(err.to_string()))) 118 | .unwrap() 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /mendes/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mendes" 3 | version = "0.9.4" 4 | edition = "2021" 5 | rust-version = "1.83" 6 | description = "Rust web toolkit for impatient perfectionists" 7 | documentation = "https://docs.rs/mendes" 8 | repository = "https://github.com/djc/mendes" 9 | keywords = ["web", "http", "server", "async"] 10 | categories = ["asynchronous", "web-programming::http-server"] 11 | license = "MIT OR Apache-2.0" 12 | workspace = ".." 
13 | readme = "../README.md" 14 | 15 | [features] 16 | default = ["application"] 17 | application = ["http", "dep:async-trait", "dep:bytes", "dep:http-body", "dep:mendes-macros", "dep:percent-encoding", "dep:pin-project", "dep:serde", "dep:serde_urlencoded"] 18 | brotli = ["compression", "async-compression?/brotli"] 19 | chrono = ["dep:chrono"] 20 | compression = ["dep:async-compression", "dep:tokio", "dep:tokio-util"] 21 | cookies = ["http", "key", "dep:chrono", "dep:data-encoding", "dep:mendes-macros", "dep:postcard", "serde?/derive"] 22 | zlib = ["compression", "async-compression?/zlib"] 23 | deflate = ["zlib"] 24 | forms = ["dep:mendes-macros", "dep:serde_urlencoded", "serde?/derive"] 25 | gzip = ["compression", "async-compression?/gzip"] 26 | hyper = ["application", "http", "dep:async-trait", "dep:bytes", "dep:futures-util", "futures-util?/std", "dep:hyper", "dep:hyper-util", "dep:tokio", "tokio?/macros", "tracing"] 27 | key = ["dep:data-encoding", "dep:ring"] 28 | json = ["dep:serde_json"] 29 | uploads = ["http", "dep:httparse", "dep:memchr"] 30 | body = ["dep:http-body"] 31 | body-util = ["dep:http-body-util", "dep:bytes", "dep:http-body"] 32 | static = ["application", "http", "dep:mime_guess", "dep:tokio", "tokio?/fs"] 33 | tracing = ["dep:tracing"] 34 | 35 | [dependencies] 36 | async-compression = { version = "0.4.0", features = ["tokio"], optional = true } 37 | async-trait = { version = "0.1.24", optional = true } 38 | bytes = { version = "1", optional = true } 39 | chrono = { version = "0.4.23", optional = true, features = ["serde"] } 40 | data-encoding = { version = "2.1.2", optional = true } 41 | futures-util = { version = "0.3.7", optional = true, default-features = false } 42 | http = { version = "1", optional = true } 43 | http-body = { version = "1", optional = true } 44 | http-body-util = { version = "0.1", optional = true } 45 | httparse = { version = "1.3.4", optional = true } 46 | hyper = { version = "1", optional = true, features = ["http1", 
"http2", "server"] } 47 | hyper-util = { version = "0.1.3", features = ["http1", "http2", "server", "tokio"], optional = true } 48 | memchr = { version = "2.5", optional = true } 49 | mendes-macros = { version = "0.4", path = "../mendes-macros", optional = true } 50 | mime_guess = { version = "2.0.3", default-features = false, optional = true } 51 | percent-encoding = { version = "2.1.0", default-features = false, optional = true } 52 | pin-project = { version = "1.1.5", optional = true } 53 | postcard = { version = "1.0.6", default-features = false, features = ["use-std"], optional = true } 54 | ring = { version = "0.17.0", optional = true } 55 | serde = { version = "1.0.104", optional = true } 56 | serde_json = { version = "1.0.48", optional = true } 57 | serde_urlencoded = { version = "0.7.0", optional = true } 58 | thiserror = { version = "2.0.0" } 59 | tokio = { version = "1", optional = true } 60 | tokio-util = { version = "0.7", optional = true, features = ["codec", "compat", "io"] } 61 | tracing = { version = "0.1.26", optional = true } 62 | 63 | [dev-dependencies] 64 | serde = { version = "1.0.104", features = ["derive"] } 65 | reqwest = { version = "0.12", default-features = false } 66 | tokio = { version = "1", features = ["macros", "rt"] } 67 | 68 | [package.metadata.docs.rs] 69 | all-features = true 70 | rustdoc-args = ["--cfg", "docsrs"] 71 | -------------------------------------------------------------------------------- /mendes/tests/cookies.rs: -------------------------------------------------------------------------------- 1 | #![cfg(feature = "cookies")] 2 | 3 | use std::convert::TryInto; 4 | use std::sync::Arc; 5 | 6 | use async_trait::async_trait; 7 | use mendes::application::IntoResponse; 8 | use mendes::cookies::{cookie, AppWithAeadKey, AppWithCookies, Key}; 9 | use mendes::http::header::{COOKIE, SET_COOKIE}; 10 | use mendes::http::request::Parts; 11 | use mendes::http::{Request, Response, StatusCode}; 12 | use mendes::{handler, route, 
Application, Context}; 13 | use serde::{Deserialize, Serialize}; 14 | 15 | #[tokio::test] 16 | async fn cookie() { 17 | let app = Arc::new(App { 18 | key: mendes::cookies::Key::new(&[ 19 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 20 | 24, 25, 26, 27, 28, 29, 30, 31, 21 | ]), 22 | }); 23 | 24 | let rsp = App::handle(Context::new(app.clone(), path_request("/store"))).await; 25 | assert_eq!(rsp.status(), StatusCode::OK); 26 | let set = rsp.headers().get(SET_COOKIE).unwrap(); 27 | let value = set.to_str().unwrap().split(';').next().unwrap(); 28 | 29 | let mut req = path_request("/extract"); 30 | req.headers_mut().insert(COOKIE, value.try_into().unwrap()); 31 | let rsp = App::handle(Context::new(app, req)).await; 32 | assert_eq!(rsp.status(), StatusCode::OK); 33 | assert_eq!(rsp.into_body(), "user = 37"); 34 | } 35 | 36 | fn path_request(path: &str) -> Request<()> { 37 | Request::builder() 38 | .uri(format!("https://example.com{path}")) 39 | .body(()) 40 | .unwrap() 41 | } 42 | 43 | struct App { 44 | key: mendes::cookies::Key, 45 | } 46 | 47 | impl AppWithAeadKey for App { 48 | fn key(&self) -> &Key { 49 | &self.key 50 | } 51 | } 52 | 53 | #[async_trait] 54 | impl Application for App { 55 | type RequestBody = (); 56 | type ResponseBody = String; 57 | type Error = Error; 58 | 59 | async fn handle(mut cx: Context) -> Response { 60 | route!(match cx.path() { 61 | Some("store") => store, 62 | Some("extract") => extract, 63 | }) 64 | } 65 | } 66 | 67 | #[handler(GET)] 68 | async fn extract(app: &App, req: &http::request::Parts) -> Result, Error> { 69 | let session = app.cookie::(&req.headers).unwrap(); 70 | Ok(Response::builder() 71 | .status(StatusCode::OK) 72 | .body(format!("user = {}", session.user)) 73 | .unwrap()) 74 | } 75 | 76 | #[handler(GET)] 77 | async fn store(app: &App) -> Result, Error> { 78 | let session = Session { user: 37 }; 79 | Ok(Response::builder() 80 | .status(StatusCode::OK) 81 | .header(SET_COOKIE, 
app.set_cookie_header(Some(session)).unwrap()) 82 | .body("Hello, world".into()) 83 | .unwrap()) 84 | } 85 | 86 | #[cookie] 87 | #[derive(Deserialize, Serialize)] 88 | struct Session { 89 | user: i32, 90 | } 91 | 92 | #[derive(Debug)] 93 | enum Error { 94 | Mendes(mendes::Error), 95 | } 96 | 97 | impl From for Error { 98 | fn from(e: mendes::Error) -> Self { 99 | Error::Mendes(e) 100 | } 101 | } 102 | 103 | impl From<&Error> for StatusCode { 104 | fn from(e: &Error) -> StatusCode { 105 | let Error::Mendes(e) = e; 106 | StatusCode::from(e) 107 | } 108 | } 109 | 110 | impl IntoResponse for Error { 111 | fn into_response(self, _: &App, _: &Parts) -> Response { 112 | let Error::Mendes(err) = self; 113 | Response::builder() 114 | .status(StatusCode::from(&err)) 115 | .body(err.to_string()) 116 | .unwrap() 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /mendes-macros/src/lib.rs: -------------------------------------------------------------------------------- 1 | extern crate proc_macro; 2 | 3 | use proc_macro::TokenStream; 4 | use quote::{quote, ToTokens}; 5 | use syn::parse_macro_input; 6 | 7 | mod cookies; 8 | mod forms; 9 | mod route; 10 | 11 | #[proc_macro_attribute] 12 | pub fn cookie(meta: TokenStream, item: TokenStream) -> TokenStream { 13 | let ast = parse_macro_input!(item as syn::ItemStruct); 14 | let meta = parse_macro_input!(meta as cookies::CookieMeta); 15 | let cookie = cookies::cookie(&meta, &ast); 16 | let mut tokens = ast.to_token_stream(); 17 | tokens.extend(cookie); 18 | TokenStream::from(tokens) 19 | } 20 | 21 | #[proc_macro_attribute] 22 | pub fn form(meta: TokenStream, item: TokenStream) -> TokenStream { 23 | let mut ast = parse_macro_input!(item as syn::ItemStruct); 24 | let meta = parse_macro_input!(meta as forms::FormMeta); 25 | let display = forms::form(&meta, &mut ast); 26 | let mut tokens = ast.to_token_stream(); 27 | tokens.extend(display); 28 | TokenStream::from(tokens) 29 | } 30 | 31 | 
/// Implement a request handler wrapper for the annotated function 32 | /// 33 | /// The attribute takes allowed methods as its arguments: 34 | /// 35 | /// ```ignore 36 | /// /// This handler will immediately return a `405 Method Not Allowed` 37 | /// /// error for all request methods other than `GET` 38 | /// #[handler(GET)] 39 | /// async fn hello(_: &App) -> Result, Error> { 40 | /// Ok(Response::builder() 41 | /// .status(StatusCode::OK) 42 | /// .body("Hello, world".into()) 43 | /// .unwrap()) 44 | /// } 45 | /// ``` 46 | /// 47 | /// The first argument of the function must be a reference to an implementor of 48 | /// the `Application` trait (the implementor may also be wrapped in an `Arc`). 49 | /// All unannotated arguments must be of types that implement the `FromContext` 50 | /// trait for the `Application` type used in the first argument. This includes 51 | /// `&http::request::Parts`, the `Request`'s headers and any number of types 52 | /// that can represent a path component from the URI: 53 | /// 54 | /// * `&[u8]` for the byte representation of the path component 55 | /// * `Cow<'_, str>` 56 | /// * `String` 57 | /// * Numeric types (`i8`, `u8`, `i16`, `u16`, ..., `isize`, `usize`, `f32`, `f64`) 58 | /// * `bool` and `char` 59 | /// * If the `hyper` feature is enabled, `hyper::body::Body` 60 | /// (only if `Application::RequestBody` is also `Body`) 61 | /// 62 | /// Each of these types can be wrapped in `Option` for optional path components. 63 | /// Additionally, there are two attributes that may be used on handler arguments: 64 | /// 65 | /// * `#[rest]`: a `&str` representing the part of the request path not yet consumed by routing 66 | /// * `#[query]`: a type that implements `Deserialize`, and will be used to deserialize the URI query 67 | /// 68 | /// This macro will generate a module that contains a `call()` function mirroring 69 | /// the original function, and you may rely on this behavior (for example, for testing). 
70 | /// 71 | /// ```ignore 72 | /// mod hello { 73 | /// use super::*; 74 | /// /// ... some internals hidden ... 75 | /// pub(super) async fn call(_: &App) -> Result, Error> { 76 | /// Ok(Response::builder() 77 | /// .status(StatusCode::OK) 78 | /// .body("Hello, world".into()) 79 | /// .unwrap()) 80 | /// } 81 | /// } 82 | /// ``` 83 | #[proc_macro_attribute] 84 | pub fn handler(meta: TokenStream, item: TokenStream) -> TokenStream { 85 | let ast = parse_macro_input!(item as syn::ItemFn); 86 | let methods = parse_macro_input!(meta as route::HandlerMethods).methods; 87 | route::handler(&methods, ast) 88 | } 89 | 90 | #[proc_macro_attribute] 91 | pub fn scope(_: TokenStream, item: TokenStream) -> TokenStream { 92 | let ast = parse_macro_input!(item as syn::ItemFn); 93 | route::scope(ast) 94 | } 95 | 96 | #[proc_macro] 97 | pub fn route(item: TokenStream) -> TokenStream { 98 | let mut ast = parse_macro_input!(item as syn::ExprMatch); 99 | route::route(&mut ast); 100 | quote!(#ast).into() 101 | } 102 | 103 | #[proc_macro_derive(ToField, attributes(option))] 104 | pub fn derive_to_field(item: TokenStream) -> TokenStream { 105 | let ast = parse_macro_input!(item as syn::DeriveInput); 106 | TokenStream::from(forms::to_field(ast)) 107 | } 108 | -------------------------------------------------------------------------------- /mendes/tests/body.rs: -------------------------------------------------------------------------------- 1 | #![cfg(all(feature = "application", feature = "hyper", feature = "body-util"))] 2 | 3 | use std::sync::Arc; 4 | 5 | #[cfg(all(feature = "compression", feature = "zlib"))] 6 | use async_compression::tokio::write::ZlibDecoder; 7 | use async_trait::async_trait; 8 | use http::header::{ACCEPT_ENCODING, CONTENT_TYPE}; 9 | use http_body_util::BodyExt; 10 | #[cfg(all(feature = "compression", feature = "zlib"))] 11 | use tokio::io::AsyncWriteExt; 12 | 13 | use mendes::application::IntoResponse; 14 | #[cfg(feature = "compression")] 15 | use 
mendes::body::EncodeResponse; 16 | use mendes::http::request::Parts; 17 | use mendes::http::{Method, Request, Response, StatusCode}; 18 | use mendes::{handler, route, Application, Body, Context}; 19 | 20 | #[cfg(feature = "json")] 21 | #[tokio::test] 22 | async fn test_json_decode() { 23 | let rsp = handle(path_request("/sum", "[1, 2, 3]", None)).await; 24 | assert_eq!(rsp.status(), StatusCode::OK); 25 | let body = rsp.into_body().collect().await.unwrap().to_bytes(); 26 | assert_eq!(String::from_utf8_lossy(&body), "6"); 27 | } 28 | 29 | #[cfg(all(feature = "compression", feature = "zlib"))] 30 | #[tokio::test] 31 | async fn test_deflate_compression() { 32 | let rsp = handle(path_request("/echo", "hello world", Some("deflate"))).await; 33 | assert_eq!(rsp.status(), StatusCode::OK); 34 | let body = rsp.into_body().collect().await.unwrap().to_bytes(); 35 | // If the lower half of the first byte is 0x08, then the stream is 36 | // a zlib stream, otherwise it's a 37 | // raw deflate stream. 38 | assert_eq!(body[0] & 0x0F, 0x8); 39 | 40 | // Decode as Zlib container 41 | let mut decoder = ZlibDecoder::new(Vec::new()); 42 | decoder.write_all(&body).await.unwrap(); 43 | decoder.shutdown().await.unwrap(); 44 | assert_eq!( 45 | String::from_utf8_lossy(&decoder.into_inner()), 46 | "hello world" 47 | ); 48 | } 49 | 50 | fn path_request(path: &str, body: &str, compression: Option<&'static str>) -> Request { 51 | let mut request = Request::builder() 52 | .method(Method::POST) 53 | .uri(format!("https://example.com{path}")) 54 | .header(CONTENT_TYPE, "application/json; charset=utf-8"); 55 | if let Some(compression) = compression { 56 | request = request.header(ACCEPT_ENCODING, compression); 57 | } 58 | request.body(body.to_owned().into()).unwrap() 59 | } 60 | 61 | async fn handle(req: Request) -> Response { 62 | App::handle(Context::new(Arc::new(App {}), req)).await 63 | } 64 | 65 | struct App {} 66 | 67 | #[async_trait] 68 | impl Application for App { 69 | type RequestBody = 
Body; 70 | type ResponseBody = Body; 71 | type Error = Error; 72 | 73 | async fn handle(mut cx: Context) -> Response { 74 | let response = route!(match cx.path() { 75 | #[cfg(feature = "json")] 76 | Some("sum") => sum, 77 | Some("echo") => echo, 78 | }); 79 | 80 | #[cfg(feature = "compression")] 81 | let response = response.encoded(&cx.req); 82 | 83 | response 84 | } 85 | } 86 | 87 | #[cfg(feature = "json")] 88 | #[handler(POST)] 89 | async fn sum(_: &App, req: &Parts, body: Body) -> Result, Error> { 90 | let numbers = App::from_body::>(req, body, 16).await.unwrap(); 91 | Ok(Response::builder() 92 | .body(numbers.iter().sum::().to_string().into()) 93 | .unwrap()) 94 | } 95 | 96 | #[handler(POST)] 97 | async fn echo(_: &App, _req: &Parts, body: Body) -> Result, Error> { 98 | let content = App::body_bytes(body, 100).await.unwrap(); 99 | Ok(Response::builder().body(content.into()).unwrap()) 100 | } 101 | 102 | #[derive(Debug)] 103 | enum Error { 104 | Mendes(mendes::Error), 105 | } 106 | 107 | impl From for Error { 108 | fn from(e: mendes::Error) -> Self { 109 | Error::Mendes(e) 110 | } 111 | } 112 | 113 | impl From<&Error> for StatusCode { 114 | fn from(e: &Error) -> StatusCode { 115 | let Error::Mendes(e) = e; 116 | StatusCode::from(e) 117 | } 118 | } 119 | 120 | impl IntoResponse for Error { 121 | fn into_response(self, _: &App, _: &Parts) -> Response { 122 | let Error::Mendes(err) = self; 123 | Response::builder() 124 | .status(StatusCode::from(&err)) 125 | .body(err.to_string().into()) 126 | .unwrap() 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /mendes-macros/src/cookies.rs: -------------------------------------------------------------------------------- 1 | use proc_macro2::Span; 2 | use quote::{format_ident, quote, ToTokens}; 3 | use syn::parse::{Parse, ParseStream}; 4 | use syn::punctuated::Punctuated; 5 | use syn::token::Comma; 6 | 7 | /// Derive the `mendes::cookies::CookieData` trait For the given 
struct 8 | /// 9 | /// Defaults to an expiry time of 6 hours. 10 | pub fn cookie(meta: &CookieMeta, ast: &syn::ItemStruct) -> proc_macro2::TokenStream { 11 | let ident = &ast.ident; 12 | let name = syn::LitStr::new(&ident.to_string(), Span::call_site()); 13 | 14 | let (http_only, max_age, path, secure) = 15 | (meta.http_only, meta.max_age, &meta.path, meta.secure); 16 | let domain = match &meta.domain { 17 | Some(v) => quote!(Some(#v)), 18 | None => quote!(None), 19 | }; 20 | let same_site = match &meta.same_site { 21 | Some(v) => { 22 | let variant = format_ident!("{}", v); 23 | quote!(Some(mendes::cookies::SameSite::#variant)) 24 | } 25 | None => quote!(None), 26 | }; 27 | 28 | quote!( 29 | impl mendes::cookies::CookieData for #ident { 30 | fn meta() -> mendes::cookies::CookieMeta<'static> { 31 | mendes::cookies::CookieMeta { 32 | domain: #domain, 33 | http_only: #http_only, 34 | max_age: #max_age, 35 | path: #path, 36 | same_site: #same_site, 37 | secure: #secure, 38 | } 39 | } 40 | 41 | const NAME: &'static str = #name; 42 | } 43 | ) 44 | } 45 | 46 | pub struct CookieMeta { 47 | domain: Option, 48 | http_only: bool, 49 | max_age: u32, 50 | path: String, 51 | same_site: Option, 52 | secure: bool, 53 | } 54 | 55 | impl Parse for CookieMeta { 56 | fn parse(input: ParseStream) -> syn::Result { 57 | let mut new = CookieMeta::default(); 58 | for field in Punctuated::::parse_terminated(input)? 
{ 59 | let value = match field.value { 60 | syn::Expr::Lit(v) => v, 61 | _ => panic!( 62 | "expected literal value for key {:?}", 63 | field.path.to_token_stream() 64 | ), 65 | }; 66 | 67 | if field.path.is_ident("domain") { 68 | match value.lit { 69 | syn::Lit::Str(v) => new.domain = Some(v.value()), 70 | _ => panic!("expected string value for key 'domain'"), 71 | } 72 | } else if field.path.is_ident("http_only") { 73 | match value.lit { 74 | syn::Lit::Bool(v) => { 75 | new.http_only = v.value(); 76 | } 77 | _ => panic!("expected string value for key 'http_only'"), 78 | } 79 | } else if field.path.is_ident("max_age") { 80 | match value.lit { 81 | syn::Lit::Int(v) => { 82 | new.max_age = v 83 | .base10_parse::() 84 | .expect("expected u32 value for key 'max_age'"); 85 | } 86 | _ => panic!("expected string value for key 'max_age'"), 87 | } 88 | } else if field.path.is_ident("path") { 89 | match value.lit { 90 | syn::Lit::Str(v) => new.path = v.value(), 91 | _ => panic!("expected string value for key 'path'"), 92 | } 93 | } else if field.path.is_ident("same_site") { 94 | match value.lit { 95 | syn::Lit::Str(v) => { 96 | let value = v.value(); 97 | new.same_site = Some(match value.as_str() { 98 | "Strict" => value, 99 | "Lax" => value, 100 | "None" => value, 101 | _ => panic!("expected 'Strict', 'Lax' or 'None' for key 'same_site'"), 102 | }); 103 | } 104 | _ => panic!("expected string value for key 'same_site'"), 105 | } 106 | } else if field.path.is_ident("secure") { 107 | match value.lit { 108 | syn::Lit::Bool(v) => { 109 | new.secure = v.value(); 110 | } 111 | _ => panic!("expected string value for key 'secure'"), 112 | } 113 | } else { 114 | panic!("unexpected key {:?}", field.path.to_token_stream()); 115 | } 116 | } 117 | 118 | if new.same_site.as_deref() == Some("Strict") && !new.secure { 119 | panic!("'same_site' is 'Strict' but 'secure' is false"); 120 | } 121 | 122 | Ok(new) 123 | } 124 | } 125 | 126 | impl Default for CookieMeta { 127 | fn default() -> Self 
{ 128 | Self { 129 | domain: None, 130 | http_only: false, 131 | max_age: 6 * 60 * 60, 132 | path: "/".to_owned(), 133 | same_site: Some("None".to_owned()), 134 | secure: true, 135 | } 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /mendes/tests/basic.rs: -------------------------------------------------------------------------------- 1 | #![cfg(feature = "application")] 2 | 3 | use std::borrow::Cow; 4 | use std::sync::Arc; 5 | 6 | use async_trait::async_trait; 7 | use mendes::application::{IntoResponse, PathState}; 8 | use mendes::http::request::Parts; 9 | use mendes::http::{Method, Request, Response, StatusCode}; 10 | use mendes::{handler, route, scope, Application, Context, FromContext}; 11 | 12 | #[tokio::test] 13 | async fn test_query() { 14 | let rsp = handle(path_request("/query?foo=3&bar=baz")).await; 15 | assert_eq!(rsp.status(), StatusCode::OK); 16 | assert_eq!(rsp.into_body(), "query: Query { foo: 3, bar: \"baz\" }"); 17 | } 18 | 19 | #[tokio::test] 20 | async fn test_method_get() { 21 | let rsp = handle(path_request("/method")).await; 22 | assert_eq!(rsp.status(), StatusCode::OK); 23 | assert_eq!(rsp.into_body(), "Hello, world"); 24 | } 25 | 26 | #[tokio::test] 27 | async fn test_method_post() { 28 | let mut req = path_request("/method/post"); 29 | *req.method_mut() = Method::POST; 30 | let rsp = handle(req).await; 31 | assert_eq!(rsp.status(), StatusCode::OK); 32 | assert_eq!(rsp.into_body(), "Hello, post"); 33 | } 34 | 35 | #[tokio::test] 36 | async fn test_magic_405() { 37 | let mut req = path_request("/method/post"); 38 | *req.method_mut() = Method::PATCH; 39 | let rsp = handle(req).await; 40 | assert_eq!(rsp.status(), StatusCode::METHOD_NOT_ALLOWED); 41 | } 42 | 43 | #[tokio::test] 44 | async fn test_nested_rest() { 45 | let rsp = handle(path_request("/nested/some/more")).await; 46 | assert_eq!(rsp.status(), StatusCode::OK); 47 | assert_eq!(rsp.into_body(), "nested rest some/more"); 48 | } 49 
| 50 | #[tokio::test] 51 | async fn test_nested_right() { 52 | let rsp = handle(path_request("/nested/right/2018")).await; 53 | assert_eq!(rsp.status(), StatusCode::OK); 54 | assert_eq!(rsp.into_body(), "nested right 2018"); 55 | } 56 | 57 | #[tokio::test] 58 | async fn test_inc_invalid() { 59 | let rsp = handle(path_request("/inc/Foo")).await; 60 | assert_eq!(rsp.status(), StatusCode::NOT_FOUND); 61 | assert_eq!(rsp.into_body(), "unable to parse path component"); 62 | } 63 | 64 | #[tokio::test] 65 | async fn test_inc() { 66 | let rsp = handle(path_request("/inc/2016")).await; 67 | assert_eq!(rsp.status(), StatusCode::OK); 68 | assert_eq!(rsp.into_body(), "num = 2017"); 69 | } 70 | 71 | #[tokio::test] 72 | async fn test_named() { 73 | let rsp = handle(path_request("/named/Foo")).await; 74 | assert_eq!(rsp.status(), StatusCode::OK); 75 | assert_eq!(rsp.into_body(), "Hello, Foo"); 76 | } 77 | 78 | #[tokio::test] 79 | async fn test_named_no_arg() { 80 | let rsp = handle(path_request("/named")).await; 81 | assert_eq!(rsp.status(), StatusCode::NOT_FOUND); 82 | assert_eq!(rsp.into_body(), "missing path component"); 83 | } 84 | 85 | #[tokio::test] 86 | async fn test_magic_404() { 87 | let rsp = handle(path_request("/foo")).await; 88 | assert_eq!(rsp.status(), StatusCode::NOT_FOUND); 89 | assert_eq!(rsp.into_body(), "no matching routes"); 90 | } 91 | 92 | #[tokio::test] 93 | async fn test_custom_error_handler() { 94 | let rsp = handle(path_request("/custom_hello/true")).await; 95 | assert_eq!(rsp.status(), StatusCode::IM_USED); 96 | 97 | // This is missing the ContextExtraction path part. In Error the resulting error is serialized as StatusCode::OK. 98 | // But we want to test the custom error handler which serializes this into StatusCode::IM_A_TEAPOT. 
99 | let rsp = handle(path_request("/custom_hello")).await; 100 | assert_eq!(rsp.status(), StatusCode::IM_A_TEAPOT); 101 | } 102 | 103 | #[tokio::test] 104 | async fn basic() { 105 | let rsp = handle(path_request("/hello")).await; 106 | assert_eq!(rsp.status(), StatusCode::OK); 107 | } 108 | 109 | fn path_request(path: &str) -> Request<()> { 110 | Request::builder() 111 | .uri(format!("https://example.com{path}")) 112 | .body(()) 113 | .unwrap() 114 | } 115 | 116 | async fn handle(req: Request<()>) -> Response { 117 | App::handle(Context::new(Arc::new(App {}), req)).await 118 | } 119 | 120 | struct App {} 121 | 122 | #[async_trait] 123 | impl Application for App { 124 | type RequestBody = (); 125 | type ResponseBody = String; 126 | type Error = Error; 127 | 128 | async fn handle(mut cx: Context) -> Response { 129 | route!(match cx.path() { 130 | Some("hello") => hello, 131 | Some("named") => named, 132 | Some("inc") => inc, 133 | Some("nested") => match cx.path() { 134 | Some("right") => nested_right, 135 | _ => nested_rest, 136 | }, 137 | Some("scoped") => scoped, 138 | Some("method") => match cx.method() { 139 | GET => hello, 140 | POST => named, 141 | }, 142 | Some("custom_hello") => custom_error, 143 | 144 | Some("query") => with_query, 145 | }) 146 | } 147 | } 148 | 149 | #[scope] 150 | async fn scoped(cx: &mut Context) -> Response { 151 | route!(match cx.path() { 152 | Some("right") => nested_right, 153 | _ => nested_rest, 154 | }) 155 | } 156 | 157 | #[handler(GET)] 158 | async fn with_query(_: &App, #[query] query: Query<'_>) -> Result, Error> { 159 | Ok(Response::builder() 160 | .status(StatusCode::OK) 161 | .body(format!("query: {query:?}")) 162 | .unwrap()) 163 | } 164 | 165 | #[derive(Debug, serde::Deserialize)] 166 | #[allow(dead_code)] // Reflected as part of the `Debug` impl 167 | struct Query<'a> { 168 | foo: usize, 169 | bar: Cow<'a, str>, 170 | } 171 | 172 | #[handler(GET)] 173 | async fn nested_rest(_: &App, #[rest] path: Cow<'_, str>) -> Result, 
Error> { 174 | Ok(Response::builder() 175 | .status(StatusCode::OK) 176 | .body(format!("nested rest {path}")) 177 | .unwrap()) 178 | } 179 | 180 | #[handler(GET)] 181 | async fn nested_right(_: &App, num: usize) -> Result, Error> { 182 | Ok(Response::builder() 183 | .status(StatusCode::OK) 184 | .body(format!("nested right {num}")) 185 | .unwrap()) 186 | } 187 | 188 | #[handler(GET)] // use mutable argument to test this case 189 | async fn inc(_: &App, mut num: usize) -> Result, Error> { 190 | num += 1; 191 | Ok(Response::builder() 192 | .status(StatusCode::OK) 193 | .body(format!("num = {num}")) 194 | .unwrap()) 195 | } 196 | 197 | #[handler(get, post)] 198 | async fn named(_: &App, name: String) -> Result, Error> { 199 | Ok(Response::builder() 200 | .status(StatusCode::OK) 201 | .body(format!("Hello, {name}")) 202 | .unwrap()) 203 | } 204 | 205 | #[handler(GET)] 206 | async fn hello(_: &App) -> Result, Error> { 207 | Ok(Response::builder() 208 | .status(StatusCode::OK) 209 | .body("Hello, world".into()) 210 | .unwrap()) 211 | } 212 | 213 | #[handler(GET)] 214 | async fn custom_error(_: &App, _x: ContextExtraction) -> Result, HandlerError> { 215 | Err(HandlerError::Test) 216 | } 217 | 218 | #[derive(Debug)] 219 | enum Error { 220 | Mendes(mendes::Error), 221 | NotTrue, 222 | } 223 | 224 | impl From for Error { 225 | fn from(e: mendes::Error) -> Self { 226 | Error::Mendes(e) 227 | } 228 | } 229 | 230 | impl From<&Error> for StatusCode { 231 | fn from(e: &Error) -> StatusCode { 232 | match e { 233 | Error::Mendes(e) => StatusCode::from(e), 234 | Error::NotTrue => StatusCode::OK, 235 | } 236 | } 237 | } 238 | 239 | impl IntoResponse for Error { 240 | fn into_response(self, _: &App, _: &Parts) -> Response { 241 | let builder = Response::builder().status(StatusCode::from(&self)); 242 | match self { 243 | Error::Mendes(err) => builder.body(err.to_string()), 244 | Error::NotTrue => builder.body("".to_string()), 245 | } 246 | .unwrap() 247 | } 248 | } 249 | 250 | enum 
HandlerError { 251 | Mendes(mendes::Error), 252 | NotTrue, 253 | Test, 254 | } 255 | 256 | impl From for HandlerError { 257 | fn from(e: mendes::Error) -> Self { 258 | Self::Mendes(e) 259 | } 260 | } 261 | 262 | impl From for HandlerError { 263 | fn from(e: Error) -> Self { 264 | match e { 265 | Error::Mendes(e) => HandlerError::Mendes(e), 266 | Error::NotTrue => HandlerError::NotTrue, 267 | } 268 | } 269 | } 270 | 271 | impl IntoResponse for HandlerError { 272 | fn into_response(self, _: &App, _: &Parts) -> Response { 273 | let builder = Response::builder(); 274 | match self { 275 | HandlerError::Mendes(err) => { 276 | builder.status(StatusCode::from(&err)).body(err.to_string()) 277 | } 278 | HandlerError::Test => builder.status(StatusCode::IM_USED).body("".to_string()), 279 | HandlerError::NotTrue => builder.status(StatusCode::IM_A_TEAPOT).body("".to_string()), 280 | } 281 | .unwrap() 282 | } 283 | } 284 | 285 | struct ContextExtraction; 286 | 287 | impl FromContext<'_, App> for ContextExtraction { 288 | fn from_context( 289 | _app: &'_ Arc, 290 | req: &'_ Parts, 291 | state: &mut PathState, 292 | _: &mut Option<::RequestBody>, 293 | ) -> Result { 294 | match state.next(req.uri.path()) { 295 | Some("true") => Ok(ContextExtraction), 296 | _ => Err(Error::NotTrue), 297 | } 298 | } 299 | } 300 | -------------------------------------------------------------------------------- /mendes-macros/src/route.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Display; 2 | 3 | use proc_macro::TokenStream; 4 | use proc_macro2::{Ident, Span}; 5 | use quote::quote; 6 | use syn::parse::{Parse, ParseStream}; 7 | use syn::parse_quote; 8 | use syn::punctuated::Punctuated; 9 | use syn::token::Comma; 10 | 11 | pub fn handler(methods: &[T], mut ast: syn::ItemFn) -> TokenStream 12 | where 13 | T: Display, 14 | { 15 | let app_type = match ast.sig.inputs.first() { 16 | Some(syn::FnArg::Typed(syn::PatType { ty, .. 
})) => match **ty { 17 | syn::Type::Reference(ref reffed) => (*reffed.elem).clone(), 18 | _ => panic!("handler's first argument must be a reference"), 19 | }, 20 | _ => panic!("handler argument lists must have &App as their first type"), 21 | }; 22 | 23 | let app_type = match &app_type { 24 | syn::Type::Path(syn::TypePath { path, .. }) => Some(path), 25 | _ => None, 26 | } 27 | .and_then(|path| path.segments.first()) 28 | .and_then(|segment| match &segment.ident { 29 | id if id == "Arc" => Some(&segment.arguments), 30 | _ => None, 31 | }) 32 | .and_then(|args| match args { 33 | syn::PathArguments::AngleBracketed(inner) => Some(inner), 34 | _ => None, 35 | }) 36 | .and_then(|args| match args.args.first() { 37 | Some(syn::GenericArgument::Type(ty)) => Some(ty.clone()), 38 | _ => None, 39 | }) 40 | .unwrap_or(app_type); 41 | 42 | let mut method_patterns = proc_macro2::TokenStream::new(); 43 | for (i, method) in methods.iter().enumerate() { 44 | let method = Ident::new(&method.to_string().to_ascii_uppercase(), Span::call_site()); 45 | method_patterns.extend(if i > 0 { 46 | quote!( | &mendes::http::Method::#method) 47 | } else { 48 | quote!(&mendes::http::Method::#method) 49 | }); 50 | } 51 | 52 | let mut done = false; 53 | let mut prefix = proc_macro2::TokenStream::new(); 54 | let mut args = proc_macro2::TokenStream::new(); 55 | for (i, arg) in ast.sig.inputs.iter_mut().enumerate() { 56 | let typed = match arg { 57 | syn::FnArg::Typed(typed) => typed, 58 | _ => panic!("did not expect receiver argument in handler"), 59 | }; 60 | 61 | let mut special = false; 62 | let (pat, ty) = (&*typed.pat, &typed.ty); 63 | let name = match pat { 64 | syn::Pat::Wild(_) => Ident::new(&format!("_{i}"), Span::call_site()), 65 | syn::Pat::Ident(pat) => pat.ident.clone(), 66 | _ => panic!("only identifiers and wildcards allowed in handler argument list"), 67 | }; 68 | 69 | typed.attrs.retain(|attr| { 70 | if attr.path().is_ident("rest") { 71 | prefix.extend(quote!( 72 | let #pat = as 
mendes::FromContext<#app_type>>::from_context( 73 | &cx.app, &cx.req, &mut cx.path, &mut cx.body, 74 | )?.0; 75 | )); 76 | args.extend(quote!(#name,)); 77 | done = true; 78 | special = true; 79 | false 80 | } else if attr.path().is_ident("query") { 81 | prefix.extend(quote!( 82 | let #pat = as mendes::FromContext<#app_type>>::from_context( 83 | &cx.app, &cx.req, &mut cx.path, &mut cx.body, 84 | )?.0; 85 | )); 86 | args.extend(quote!(#name,)); 87 | special = true; 88 | false 89 | } else { 90 | true 91 | } 92 | }); 93 | 94 | if special { 95 | continue; 96 | } else if done { 97 | panic!("more arguments after #[rest] not allowed"); 98 | } 99 | 100 | prefix.extend(quote!( 101 | let #name = <#ty as mendes::FromContext<#app_type>>::from_context( 102 | &cx.app, &cx.req, &mut cx.path, &mut cx.body, 103 | )?; 104 | )); 105 | args.extend(quote!(#name,)); 106 | } 107 | 108 | let name = ast.sig.ident.clone(); 109 | let orig_vis = ast.vis.clone(); 110 | ast.vis = nested_visibility(ast.vis); 111 | 112 | let handler = { 113 | let nested_vis = &ast.vis; 114 | let generics = &ast.sig.generics; 115 | let rtype = &ast.sig.output; 116 | let where_clause = &ast.sig.generics.where_clause; 117 | quote!( 118 | #nested_vis async fn handler #generics( 119 | cx: &mut mendes::application::Context<#app_type> 120 | ) #rtype #where_clause { 121 | match &cx.req.method { 122 | #method_patterns => {} 123 | _ => return Err(mendes::Error::MethodNotAllowed.into()), 124 | } 125 | #prefix 126 | call(#args).await 127 | } 128 | ) 129 | }; 130 | 131 | let call = { 132 | ast.sig.ident = Ident::new("call", Span::call_site()); 133 | quote!(#ast) 134 | }; 135 | 136 | quote!(#orig_vis mod #name { 137 | use super::*; 138 | #handler 139 | #call 140 | }) 141 | .into() 142 | } 143 | 144 | fn nested_visibility(vis: syn::Visibility) -> syn::Visibility { 145 | match vis { 146 | cur @ syn::Visibility::Public(_) => cur, 147 | syn::Visibility::Inherited => visibility("super"), 148 | cur @ syn::Visibility::Restricted(_) => 
{ 149 | let inner = match &cur { 150 | syn::Visibility::Restricted(inner) => inner, 151 | _ => unreachable!(), 152 | }; 153 | 154 | if inner.path.is_ident("crate") { 155 | visibility("crate") 156 | } else if inner.path.is_ident("self") { 157 | visibility("super") 158 | } else if inner.path.is_ident("super") { 159 | visibility("super::super") 160 | } else { 161 | cur 162 | } 163 | } 164 | } 165 | } 166 | 167 | fn visibility(path: &str) -> syn::Visibility { 168 | let path = Ident::new(path, Span::call_site()); 169 | parse_quote!(pub(in #path)) 170 | } 171 | 172 | pub fn scope(mut ast: syn::ItemFn) -> TokenStream { 173 | let orig_ident = ast.sig.ident.clone(); 174 | let orig_vis = ast.vis.clone(); 175 | 176 | ast.vis = nested_visibility(ast.vis); 177 | ast.sig.ident = Ident::new("handler", Span::call_site()); 178 | 179 | quote!(#orig_vis mod #orig_ident { 180 | use super::*; 181 | #ast 182 | }) 183 | .into() 184 | } 185 | 186 | pub fn route(ast: &mut syn::ExprMatch) { 187 | let (cx, ty) = match &*ast.expr { 188 | syn::Expr::MethodCall(call) => { 189 | let ty = match &call.method { 190 | id if id == "path" => RouteType::Path, 191 | id if id == "method" => RouteType::Method, 192 | m => panic!("unroutable method {m:?}"), 193 | }; 194 | 195 | let cx = match &*call.receiver { 196 | syn::Expr::Path(p) if p.path.get_ident().is_some() => { 197 | p.path.get_ident().unwrap().clone() 198 | } 199 | _ => panic!("inner expression must method call on identifier"), 200 | }; 201 | 202 | match ty { 203 | RouteType::Path => { 204 | let expr = &*ast.expr; 205 | *ast.expr = parse_quote!(#expr.as_deref()); 206 | } 207 | RouteType::Method => { 208 | let expr = &*ast.expr; 209 | *ast.expr = parse_quote!(*#expr); 210 | } 211 | } 212 | 213 | (cx, ty) 214 | } 215 | _ => panic!("expected method call in match expression"), 216 | }; 217 | 218 | let mut wildcard = false; 219 | for arm in ast.arms.iter_mut() { 220 | let mut rewind = false; 221 | if let syn::Pat::Wild(_) = arm.pat { 222 | wildcard = 
true; 223 | rewind = true; 224 | } 225 | 226 | if let RouteType::Method = ty { 227 | match &mut arm.pat { 228 | syn::Pat::Ident(method) => { 229 | arm.pat = parse_quote!(mendes::http::Method::#method); 230 | } 231 | _ => panic!("method pattern must be an identifier"), 232 | } 233 | } 234 | 235 | match &mut *arm.body { 236 | syn::Expr::Path(path) => { 237 | let rewind = rewind.then(|| quote!(#cx.rewind();)); 238 | *arm.body = parse_quote!({ 239 | #rewind 240 | let rsp = #path::handler(#cx.as_mut()).await; 241 | ::mendes::application::IntoResponse::into_response(rsp, &*#cx.app, &cx.req) 242 | }); 243 | } 244 | syn::Expr::Match(inner) => route(inner), 245 | _ => panic!("only identifiers, paths and match expressions allowed"), 246 | } 247 | } 248 | 249 | if !wildcard { 250 | let variant = match ty { 251 | RouteType::Path => quote!(PathNotFound), 252 | RouteType::Method => quote!(MethodNotAllowed), 253 | }; 254 | 255 | ast.arms.push(parse_quote!( 256 | _ => { 257 | let e = ::mendes::Error::#variant; 258 | ::mendes::application::IntoResponse::into_response(e, &*#cx.app, &cx.req) 259 | } 260 | )); 261 | } 262 | } 263 | 264 | enum RouteType { 265 | Path, 266 | Method, 267 | } 268 | 269 | pub struct HandlerMethods { 270 | pub methods: Vec, 271 | } 272 | 273 | impl Parse for HandlerMethods { 274 | fn parse(input: ParseStream) -> syn::Result { 275 | let methods = Punctuated::::parse_terminated(input)?; 276 | Ok(Self { 277 | methods: methods.into_iter().collect(), 278 | }) 279 | } 280 | } 281 | -------------------------------------------------------------------------------- /mendes/src/hyper.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::convert::Infallible; 3 | use std::error::Error as StdError; 4 | use std::future::{pending, Future, Pending}; 5 | use std::io; 6 | use std::marker::Send; 7 | use std::net::SocketAddr; 8 | use std::panic::AssertUnwindSafe; 9 | use std::pin::{pin, Pin}; 10 | use 
std::sync::Arc; 11 | use std::time::Duration; 12 | 13 | use futures_util::future::{CatchUnwind, FutureExt, Map}; 14 | use http::request::Parts; 15 | use http::{Request, Response, StatusCode}; 16 | use hyper::body::{Body, Incoming}; 17 | use hyper::service::Service; 18 | use hyper_util::rt::{TokioExecutor, TokioIo}; 19 | use hyper_util::server::conn::auto::Builder; 20 | use tokio::net::{TcpListener, TcpStream}; 21 | use tokio::sync::watch; 22 | use tokio::time::sleep; 23 | use tracing::{debug, error, info}; 24 | 25 | use super::Application; 26 | use crate::application::{Context, FromContext, PathState}; 27 | 28 | pub use hyper::body; 29 | 30 | pub struct Server { 31 | listener: TcpListener, 32 | app: Arc, 33 | signal: Option, 34 | } 35 | 36 | impl Server> { 37 | pub async fn bind(address: SocketAddr, app: A) -> Result>, io::Error> { 38 | Ok(Self::new(TcpListener::bind(address).await?, app)) 39 | } 40 | 41 | pub fn new(listener: TcpListener, app: A) -> Server> { 42 | Server { 43 | listener, 44 | app: Arc::new(app), 45 | signal: None, 46 | } 47 | } 48 | } 49 | 50 | impl Server> { 51 | pub fn with_graceful_shutdown>(self, signal: F) -> Server { 52 | let Server { listener, app, .. 
} = self; 53 | Server { 54 | listener, 55 | app, 56 | signal: Some(signal), 57 | } 58 | } 59 | } 60 | 61 | impl Server 62 | where 63 | A: Application + Sync + 'static, 64 | A::RequestBody: From, 65 | <::ResponseBody as Body>::Data: Send, 66 | <::ResponseBody as Body>::Error: StdError + Send + Sync, 67 | ::ResponseBody: From<&'static str> + Send, 68 | F: Future + Send + 'static, 69 | { 70 | pub async fn serve(self) -> Result<(), io::Error> { 71 | let Server { 72 | listener, 73 | app, 74 | signal, 75 | } = self; 76 | 77 | let (listener_state, conn_state) = states(signal); 78 | let mut shutting_down = pin!(async move { 79 | match listener_state.shutting_down { 80 | Some(shutting_down) => shutting_down.closed().await, 81 | None => pending().await, 82 | } 83 | } 84 | .fuse()); 85 | 86 | loop { 87 | let (stream, addr) = tokio::select! { 88 | res = listener.accept() => { 89 | match res { 90 | Ok((stream, addr)) => (stream, addr), 91 | Err(error) => { 92 | use io::ErrorKind::*; 93 | if matches!(error.kind(), ConnectionRefused | ConnectionAborted | ConnectionReset) { 94 | continue; 95 | } 96 | 97 | // Sleep for a bit to see if the error clears 98 | error!(%error, "error accepting connection"); 99 | sleep(Duration::from_secs(1)).await; 100 | continue; 101 | } 102 | } 103 | } 104 | _ = shutting_down.as_mut() => break, 105 | }; 106 | 107 | debug!("connection accepted from {addr}"); 108 | tokio::spawn( 109 | Connection { 110 | stream, 111 | addr, 112 | state: conn_state.clone(), 113 | app: app.clone(), 114 | } 115 | .run(), 116 | ); 117 | } 118 | 119 | let ListenerState { task_monitor, .. 
} = listener_state; 120 | drop(conn_state); 121 | drop(listener); 122 | if let Some(task_monitor) = task_monitor { 123 | let tasks = task_monitor.receiver_count(); 124 | if tasks > 0 { 125 | debug!("waiting for {tasks} task(s) to finish"); 126 | } 127 | task_monitor.closed().await; 128 | } 129 | 130 | Ok(()) 131 | } 132 | } 133 | 134 | fn states( 135 | future: Option + Send + 'static>, 136 | ) -> (ListenerState, ConnectionState) { 137 | let future = match future { 138 | Some(future) => future, 139 | None => return (ListenerState::default(), ConnectionState::default()), 140 | }; 141 | 142 | let (shutting_down, signal) = watch::channel(()); // Axum: `signal_tx`, `signal_rx` 143 | let shutting_down = Arc::new(shutting_down); 144 | tokio::spawn(async move { 145 | future.await; 146 | info!("shutdown signal received, draining..."); 147 | drop(signal); 148 | }); 149 | 150 | let (task_monitor, task_done) = watch::channel(()); // Axum: `close_tx`, `close_rx` 151 | ( 152 | ListenerState { 153 | shutting_down: Some(shutting_down.clone()), 154 | task_monitor: Some(task_monitor), 155 | }, 156 | ConnectionState { 157 | shutting_down: Some(shutting_down), 158 | _task_done: Some(task_done), 159 | }, 160 | ) 161 | } 162 | 163 | #[derive(Default)] 164 | struct ListenerState { 165 | /// If `Some` and `closed()`, the server is shutting down 166 | shutting_down: Option>>, 167 | /// If `Some`, `receiver_count()` can be used whether any connections are still going 168 | /// 169 | /// Call `closed().await` to wait for all connections to finish. 
170 | task_monitor: Option>, 171 | } 172 | 173 | struct Connection { 174 | stream: TcpStream, 175 | addr: SocketAddr, 176 | state: ConnectionState, 177 | app: Arc, 178 | } 179 | 180 | impl Connection 181 | where 182 | A::RequestBody: From, 183 | A::ResponseBody: From<&'static str> + Send, 184 | ::Data: Send, 185 | ::Error: StdError + Send + Sync, 186 | { 187 | async fn run(self) { 188 | let Connection { 189 | stream, 190 | addr, 191 | state, 192 | app, 193 | } = self; 194 | 195 | let service = ConnectionService { addr, app }; 196 | 197 | let builder = Builder::new(TokioExecutor::new()); 198 | let stream = TokioIo::new(stream); 199 | let mut conn = pin!(builder.serve_connection_with_upgrades(stream, service)); 200 | let mut shutting_down = pin!(async move { 201 | match state.shutting_down { 202 | Some(shutting_down) => shutting_down.closed().await, 203 | None => pending().await, 204 | } 205 | } 206 | .fuse()); 207 | 208 | loop { 209 | tokio::select! { 210 | result = conn.as_mut() => { 211 | if let Err(error) = result { 212 | error!(%addr, %error, "failed to serve connection"); 213 | } 214 | break; 215 | } 216 | _ = shutting_down.as_mut() => { 217 | debug!("shutting down connection to {addr}"); 218 | conn.as_mut().graceful_shutdown(); 219 | } 220 | } 221 | } 222 | 223 | debug!("connection to {addr} closed"); 224 | } 225 | } 226 | 227 | #[derive(Clone, Default)] 228 | struct ConnectionState { 229 | /// If `Some` and `closed()`, the server is shutting down; don't accept new requests 230 | shutting_down: Option>>, 231 | /// Keeping this around will allow the server to wait for the connection to finish 232 | _task_done: Option>, 233 | } 234 | 235 | pub struct ConnectionService { 236 | addr: SocketAddr, 237 | app: Arc, 238 | } 239 | 240 | impl Service> for ConnectionService 241 | where 242 | A::RequestBody: From, 243 | A::ResponseBody: From<&'static str>, 244 | { 245 | type Response = Response; 246 | type Error = Infallible; 247 | type Future = UnwindSafeHandlerFuture; 
248 | 249 | fn call(&self, mut req: Request) -> Self::Future { 250 | req.extensions_mut().insert(ClientAddr(self.addr)); 251 | let cx = Context::new(self.app.clone(), req.map(|body| body.into())); 252 | AssertUnwindSafe(A::handle(cx)) 253 | .catch_unwind() 254 | .map(panic_response) 255 | } 256 | } 257 | 258 | type UnwindSafeHandlerFuture = Map< 259 | CatchUnwind + Send>>>>, 260 | fn(Result>) -> Result, 261 | >; 262 | 263 | fn panic_response>( 264 | result: Result, Box>, 265 | ) -> Result, Infallible> { 266 | #[allow(unused_variables)] // Depends on features 267 | let error = match result { 268 | Ok(rsp) => return Ok(rsp), 269 | Err(e) => e, 270 | }; 271 | 272 | #[cfg(feature = "tracing")] 273 | { 274 | let panic_str = if let Some(s) = error.downcast_ref::() { 275 | Some(s.as_str()) 276 | } else if let Some(s) = error.downcast_ref::<&'static str>() { 277 | Some(*s) 278 | } else { 279 | Some("no error") 280 | }; 281 | 282 | tracing::error!("caught panic from request handler: {:?}", panic_str); 283 | } 284 | 285 | Ok(Response::builder() 286 | .status(StatusCode::INTERNAL_SERVER_ERROR) 287 | .body("Caught panic".into()) 288 | .unwrap()) 289 | } 290 | 291 | impl<'a, A: Application> FromContext<'a, A> for Incoming { 292 | fn from_context( 293 | _: &'a Arc, 294 | _: &'a Parts, 295 | _: &mut PathState, 296 | body: &mut Option, 297 | ) -> Result { 298 | match body.take() { 299 | Some(body) => Ok(body), 300 | None => panic!("attempted to retrieve body twice"), 301 | } 302 | } 303 | } 304 | 305 | impl<'a, A: Application> FromContext<'a, A> for ClientAddr { 306 | fn from_context( 307 | _: &'a Arc, 308 | req: &'a Parts, 309 | _: &mut PathState, 310 | _: &mut Option, 311 | ) -> Result { 312 | // This is safe because we insert ClientAddr into the request extensions 313 | // unconditionally in the ConnectionService::call method. 
314 | Ok(req.extensions.get::().copied().unwrap()) 315 | } 316 | } 317 | 318 | #[derive(Debug, Clone, Copy)] 319 | pub struct ClientAddr(SocketAddr); 320 | 321 | impl std::ops::Deref for ClientAddr { 322 | type Target = SocketAddr; 323 | 324 | fn deref(&self) -> &Self::Target { 325 | &self.0 326 | } 327 | } 328 | 329 | impl From for ClientAddr { 330 | fn from(addr: SocketAddr) -> Self { 331 | Self(addr) 332 | } 333 | } 334 | -------------------------------------------------------------------------------- /mendes-macros/src/forms.rs: -------------------------------------------------------------------------------- 1 | use std::mem; 2 | 3 | use proc_macro2::Span; 4 | use quote::{quote, ToTokens}; 5 | use syn::parse::{Parse, ParseStream}; 6 | use syn::punctuated::Punctuated; 7 | use syn::token::Comma; 8 | 9 | pub fn form(meta: &FormMeta, ast: &mut syn::ItemStruct) -> proc_macro2::TokenStream { 10 | let fields = match &mut ast.fields { 11 | syn::Fields::Named(fields) => fields, 12 | _ => panic!("only structs with named fields are supported"), 13 | }; 14 | 15 | let mut item_state = None; 16 | let mut new = proc_macro2::TokenStream::new(); 17 | for field in fields.named.iter_mut() { 18 | let name = field.ident.as_ref().unwrap().to_string(); 19 | let mut label = { 20 | let label = syn::LitStr::new(&label(&name), Span::call_site()); 21 | quote!(Some(#label.into())) 22 | }; 23 | let mut item = None; 24 | let mut skip = false; 25 | 26 | let params = if let Some((i, attr)) = field 27 | .attrs 28 | .iter_mut() 29 | .enumerate() 30 | .find(|(_, a)| a.path().is_ident("form")) 31 | { 32 | let input = match &mut attr.meta { 33 | syn::Meta::List(list) => { 34 | mem::replace(&mut list.tokens, proc_macro2::TokenStream::new()) 35 | } 36 | _ => panic!("expected list in form attribute"), 37 | }; 38 | 39 | let mut tokens = proc_macro2::TokenStream::new(); 40 | for (key, value) in syn::parse2::(input).unwrap().params { 41 | if key == "type" && value == "hidden" { 42 | label = quote!(None); 
43 | } else if key == "label" { 44 | label = quote!(Some(#value.into())); 45 | } else if key == "item" { 46 | item = Some(value.clone()); 47 | } else if key == "skip" { 48 | skip = true; 49 | } 50 | tokens.extend(quote!( 51 | (#key, #value), 52 | )); 53 | } 54 | field.attrs.remove(i); 55 | tokens 56 | } else { 57 | quote!() 58 | }; 59 | 60 | if skip { 61 | continue; 62 | } 63 | 64 | let ty = &field.ty; 65 | let tokens = quote!( 66 | mendes::forms::Item { 67 | label: #label, 68 | contents: mendes::forms::ItemContents::Single( 69 | <#ty as mendes::forms::ToField>::to_field(#name.into(), &[#params]) 70 | ), 71 | }, 72 | ); 73 | 74 | item_state = match item_state { 75 | None if item.is_none() => { 76 | new.extend(tokens); 77 | None 78 | } 79 | None => Some((item.unwrap(), tokens)), 80 | Some((name, mut items)) => match item { 81 | Some(cur) if cur == name => { 82 | items.extend(tokens); 83 | Some((name, items)) 84 | } 85 | Some(cur) => { 86 | let label = syn::LitStr::new(&name, Span::call_site()); 87 | new.extend(quote!( 88 | mendes::forms::Item { 89 | label: Some(#label.into()), 90 | contents: mendes::forms::ItemContents::Multi(vec![#items]), 91 | }, 92 | )); 93 | Some((cur, tokens)) 94 | } 95 | None => { 96 | let label = syn::LitStr::new(&name, Span::call_site()); 97 | new.extend(quote!( 98 | mendes::forms::Item { 99 | label: Some(#label.into()), 100 | contents: mendes::forms::ItemContents::Multi(vec![#items]), 101 | }, 102 | )); 103 | new.extend(tokens); 104 | None 105 | } 106 | }, 107 | } 108 | } /* Flush a group that is still pending after the last field: a group is otherwise only emitted when a later field starts a different group or has no `item`, so a trailing group would be dropped from the form. */ if let Some((group, items)) = item_state { let label = syn::LitStr::new(&group, Span::call_site()); new.extend(quote!( mendes::forms::Item { label: Some(#label.into()), contents: mendes::forms::ItemContents::Multi(vec![#items]), }, )); } 109 | 110 | let FormMeta { 111 | action, 112 | classes, 113 | submit, 114 | } = &meta; 115 | let submit = match submit { 116 | Some(s) => quote!(Some(#s.into())), 117 | None => quote!(None), 118 | }; 119 | 120 | new.extend(quote!( 121 | mendes::forms::Item { 122 | label: None, 123 | contents: mendes::forms::ItemContents::Single( 124 | mendes::forms::Field::Submit(mendes::forms::Submit { 125 | value: #submit, 126 | }) 127 | ), 128 | }, 129 | )); 130 | 131 | let action = 
match action { 132 | Some(s) => quote!(Some(#s.into())), 133 | None => quote!(None), 134 | }; 135 | 136 | let name = &ast.ident; 137 | let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl(); 138 | let display = quote!( 139 | impl #impl_generics mendes::forms::ToForm for #name #type_generics #where_clause { 140 | fn to_form() -> mendes::forms::Form { 141 | mendes::forms::Form { 142 | action: #action, 143 | enctype: None, 144 | method: Some("post".into()), 145 | classes: #classes, 146 | sets: vec![ 147 | mendes::forms::FieldSet { 148 | legend: None, 149 | items: vec![ 150 | #new 151 | ], 152 | } 153 | ], 154 | }.prepare() 155 | } 156 | } 157 | ); 158 | 159 | display 160 | } 161 | 162 | pub struct FormMeta { 163 | action: Option, 164 | submit: Option, 165 | classes: proc_macro2::TokenStream, 166 | } 167 | 168 | impl Parse for FormMeta { 169 | fn parse(input: ParseStream) -> syn::Result { 170 | let (mut action, mut submit, mut classes) = (None, None, quote!(vec![])); 171 | for field in Punctuated::::parse_terminated(input)? 
{ 172 | let value = match field.value { 173 | syn::Expr::Lit(v) => v, 174 | _ => panic!( 175 | "expected literal value for key {:?}", 176 | field.path.to_token_stream() 177 | ), 178 | }; 179 | 180 | if field.path.is_ident("action") { 181 | match value.lit { 182 | syn::Lit::Str(v) => { 183 | action = Some(v.value()); 184 | } 185 | _ => panic!("expected string value for key 'action'"), 186 | } 187 | } else if field.path.is_ident("submit") { 188 | match value.lit { 189 | syn::Lit::Str(v) => { 190 | submit = Some(v.value()); 191 | } 192 | _ => panic!("expected string value for key 'submit'"), 193 | } 194 | } else if field.path.is_ident("class") { 195 | match value.lit { 196 | syn::Lit::Str(v) => { 197 | let val = v.value(); 198 | let iter = val.split(' '); 199 | classes = quote!(vec![#(#iter.into()),*]); 200 | } 201 | _ => panic!("expected string value for key 'class'"), 202 | } 203 | } else { 204 | panic!("unexpected field {:?}", field.path.to_token_stream()); 205 | } 206 | } 207 | 208 | Ok(Self { 209 | action, 210 | submit, 211 | classes, 212 | }) 213 | } 214 | } 215 | 216 | pub fn to_field(mut ast: syn::DeriveInput) -> proc_macro2::TokenStream { 217 | let item = match &mut ast.data { 218 | syn::Data::Enum(item) => item, 219 | _ => panic!("only enums can derive ToField for now"), 220 | }; 221 | 222 | let mut options = proc_macro2::TokenStream::new(); 223 | for variant in item.variants.iter_mut() { 224 | match variant.fields { 225 | syn::Fields::Unit => {} 226 | _ => panic!("only unit variants are supported for now"), 227 | }; 228 | 229 | let params = if let Some((i, attr)) = variant 230 | .attrs 231 | .iter_mut() 232 | .enumerate() 233 | .find(|(_, a)| a.path().is_ident("option")) 234 | { 235 | let input = match &mut attr.meta { 236 | syn::Meta::List(list) => { 237 | mem::replace(&mut list.tokens, proc_macro2::TokenStream::new()) 238 | } 239 | _ => panic!("expected list in form attribute"), 240 | }; 241 | 242 | let params = syn::parse2::(input).unwrap().params; 243 | 
variant.attrs.remove(i); 244 | params 245 | } else { 246 | vec![] 247 | }; 248 | 249 | let name = variant.ident.to_string(); 250 | let label = params 251 | .iter() 252 | .find_map(|(key, value)| { 253 | if key == "label" { 254 | Some(quote!(#value.into())) 255 | } else { 256 | None 257 | } 258 | }) 259 | .unwrap_or_else(|| quote!(#name.into())); 260 | 261 | options.extend(quote!( 262 | mendes::forms::SelectOption { 263 | label: #label, 264 | value: #name.into(), 265 | disabled: false, 266 | selected: false, 267 | }, 268 | )); 269 | } 270 | 271 | let ident = &ast.ident; 272 | /* Name the trait by its full path, like every other item in the generated code, so the derive also works where `ToField` is not imported. */ quote!( 273 | impl mendes::forms::ToField for #ident { 274 | fn to_field(name: std::borrow::Cow<'static, str>, _: &[(&str, &str)]) -> mendes::forms::Field { 275 | mendes::forms::Field::Select(mendes::forms::Select { 276 | name, 277 | options: vec![#options], 278 | }) 279 | } 280 | } 281 | ) 282 | } 283 | 284 | pub struct FieldParams { 285 | pub params: Vec<(String, String)>, 286 | } 287 | 288 | impl Parse for FieldParams { 289 | fn parse(input: ParseStream) -> syn::Result { 290 | Ok(Self { 291 | params: Punctuated::::parse_terminated(input)? 
292 | .into_iter() 293 | .map(|meta| match meta { 294 | syn::Meta::NameValue(meta) => { 295 | let key = meta.path.get_ident().unwrap().to_string(); 296 | let value = meta.value.into_token_stream().to_string(); 297 | let value = value.trim_matches('"').to_string(); 298 | (key, value) 299 | } 300 | syn::Meta::Path(path) => { 301 | let key = path.get_ident().unwrap().to_string(); 302 | (key, "true".into()) 303 | } 304 | _ => unimplemented!(), 305 | }) 306 | .collect(), 307 | }) 308 | } 309 | } 310 | 311 | fn label(s: &str) -> String { 312 | let mut new = String::with_capacity(s.len()); 313 | for (i, c) in s.chars().enumerate() { 314 | if i == 0 { 315 | new.extend(c.to_uppercase()); 316 | } else if c == '_' { 317 | new.push(' '); 318 | } else { 319 | new.push(c); 320 | } 321 | } 322 | new 323 | } 324 | -------------------------------------------------------------------------------- /mendes/src/cookies.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "application")] 2 | use std::convert::TryFrom; 3 | #[cfg(feature = "application")] 4 | use std::fmt::Write; 5 | use std::str; 6 | #[cfg(feature = "application")] 7 | use std::time::Duration; 8 | use std::time::SystemTime; 9 | 10 | use data_encoding::BASE64URL_NOPAD; 11 | use http::header::InvalidHeaderValue; 12 | #[cfg(feature = "application")] 13 | use http::header::COOKIE; 14 | #[cfg(feature = "application")] 15 | use http::{HeaderMap, HeaderValue}; 16 | use serde::{de::DeserializeOwned, Deserialize, Serialize}; 17 | use thiserror::Error; 18 | 19 | #[cfg(feature = "application")] 20 | use crate::key::{NONCE_LEN, TAG_LEN}; 21 | 22 | pub use crate::key::Key; 23 | pub use mendes_macros::cookie; 24 | 25 | #[cfg(feature = "application")] 26 | #[cfg_attr(docsrs, doc(cfg(feature = "application")))] 27 | pub use application::{AppWithAeadKey, AppWithCookies}; 28 | 29 | #[cfg(feature = "application")] 30 | #[cfg_attr(docsrs, doc(cfg(feature = "application")))] 31 | mod 
application { 32 | use super::*; 33 | use http::header::SET_COOKIE; 34 | 35 | pub use crate::key::AppWithAeadKey; 36 | 37 | /// Cookie manipulation methods, contingent upon the `Application`'s access to an AEAD `Key` 38 | pub trait AppWithCookies: AppWithAeadKey { 39 | /// Extract cookie from the given `HeaderMap` using this `Application`'s `Key` 40 | /// 41 | /// Finds the first `Cookie` header whose name matches the given type `T` and 42 | /// whose value can be successfully decoded, decrypted and has not expired. 43 | fn cookie(&self, headers: &HeaderMap) -> Option { 44 | extract(self.key(), headers) 45 | } 46 | 47 | /// Set cookie value by appending a `Set-Cookie` to the given `HeaderMap` 48 | /// 49 | /// If `data` is `Some`, a new value will be set. If the value is `None`, an 50 | /// empty value is set with an expiry time in the past, causing the cookie 51 | /// to be deleted in compliant clients. 52 | fn set_cookie( 53 | &self, 54 | headers: &mut HeaderMap, 55 | data: Option, 56 | ) -> Result<(), Error> { 57 | headers.append(SET_COOKIE, self.set_cookie_header(data)?); 58 | Ok(()) 59 | } 60 | 61 | /// Encode and encrypt the cookie's value into a `Set-Cookie` `HeaderValue` 62 | /// 63 | /// If `data` is `Some`, a new value will be set. If the value is `None`, an 64 | /// empty value is set with an expiry time in the past, causing the cookie 65 | /// to be deleted in compliant clients. 
66 | fn set_cookie_header( 67 | &self, 68 | data: Option, 69 | ) -> Result { 70 | self.set_cookie_from_parts(T::NAME, data, &T::meta()) 71 | } 72 | 73 | /// Assemble a `Set-Cookie` `HeaderValue` from parts 74 | fn set_cookie_from_parts( 75 | &self, 76 | name: &str, 77 | value: Option, 78 | meta: &CookieMeta<'_>, 79 | ) -> Result { 80 | let value = value 81 | .map(|data| Cookie::encode(name, data, meta, self.key())) 82 | .transpose()?; 83 | cookie(name, value.as_deref(), meta) 84 | } 85 | } 86 | 87 | impl AppWithCookies for A {} 88 | } 89 | 90 | /// Data to be stored in a cookie 91 | /// 92 | /// This is usually derived through the `cookie` procedural attribute macro. 93 | pub trait CookieData { 94 | fn decode(value: &str, key: &Key) -> Option 95 | where 96 | Self: DeserializeOwned, 97 | { 98 | let mut bytes = BASE64URL_NOPAD.decode(value.as_bytes()).ok()?; 99 | let plain = key.decrypt(Self::NAME.as_bytes(), &mut bytes).ok()?; 100 | 101 | let cookie = postcard::from_bytes::>(plain).ok()?; 102 | match SystemTime::now() < cookie.expires { 103 | true => Some(cookie.data), 104 | false => None, 105 | } 106 | } 107 | 108 | fn meta() -> CookieMeta<'static> { 109 | CookieMeta::default() 110 | } 111 | 112 | /// The name to use for the cookie 113 | const NAME: &'static str; 114 | } 115 | 116 | pub struct CookieMeta<'a> { 117 | /// Defines the host to which the cookie will be sent 118 | pub domain: Option<&'a str>, 119 | /// Forbid JavaScript access to the cookie 120 | /// 121 | /// Defaults to `false`. 122 | pub http_only: bool, 123 | /// The maximum age for the cookie in seconds 124 | /// 125 | /// Defaults to 6 hours. 126 | pub max_age: u32, 127 | /// Set a path prefix to constrain use of the cookie 128 | /// 129 | /// The browser default here is to use the current directory (removing the last path 130 | /// segment from the current URL), which seems pretty useless. Instead, we default to `/` here. 
131 | pub path: &'a str, 132 | /// Controls whether the cookie is sent with cross-origin requests 133 | /// 134 | /// Defaults to `Some(SameSite::None)`. 135 | pub same_site: Option, 136 | /// Restrict the cookie to being sent only over secure connections 137 | /// 138 | /// Defaults to `true`. 139 | pub secure: bool, 140 | } 141 | 142 | impl Default for CookieMeta<'static> { 143 | fn default() -> Self { 144 | Self { 145 | domain: None, 146 | http_only: false, 147 | max_age: 6 * 60 * 60, 148 | path: "/", 149 | same_site: Some(SameSite::None), 150 | secure: true, 151 | } 152 | } 153 | } 154 | 155 | #[derive(Deserialize, Serialize)] 156 | #[serde(bound(deserialize = "T: DeserializeOwned"))] 157 | struct Cookie { 158 | expires: SystemTime, 159 | data: T, 160 | } 161 | 162 | #[cfg(feature = "application")] 163 | impl Cookie { 164 | fn encode(name: &str, data: T, meta: &CookieMeta<'_>, key: &Key) -> Result { 165 | let expires = SystemTime::now() 166 | .checked_add(Duration::new(meta.max_age as u64, 0)) 167 | .ok_or(Error::ExpiryWindowTooLong)?; 168 | 169 | let mut bytes = postcard::to_stdvec(&Cookie { expires, data })?; 170 | key.encrypt(name.as_bytes(), &mut bytes)?; 171 | Ok(BASE64URL_NOPAD.encode(&bytes)) 172 | } 173 | } 174 | 175 | #[cfg(feature = "application")] 176 | fn extract(key: &Key, headers: &HeaderMap) -> Option { 177 | let name = T::NAME; 178 | // HTTP/2 allows for multiple cookie headers. 179 | // https://datatracker.ietf.org/doc/html/rfc9113#name-compressing-the-cookie-head 180 | for value in headers.get_all(COOKIE) { 181 | let value = match str::from_utf8(value.as_ref()) { 182 | Ok(value) => value, 183 | Err(_) => continue, 184 | }; 185 | // A single cookie header can contain multiple cookies (delimited by ;) 186 | // even if there are multiple cookie headers. 
187 | for cookie in value.split(';') { 188 | let cookie = cookie.trim_start(); 189 | if cookie.len() < (name.len() + 1 + NONCE_LEN + TAG_LEN) 190 | || !cookie.starts_with(name) 191 | || cookie.as_bytes()[name.len()] != b'=' 192 | { 193 | continue; 194 | } 195 | 196 | let encoded = &cookie[name.len() + 1..]; 197 | match T::decode(encoded, key) { 198 | Some(data) => return Some(data), 199 | None => continue, 200 | } 201 | } 202 | } 203 | None 204 | } 205 | 206 | #[cfg(feature = "application")] 207 | fn cookie(name: &str, value: Option<&str>, meta: &CookieMeta<'_>) -> Result { 208 | let mut s = match value { 209 | Some(value) => format!( 210 | "{}={}; Max-Age={}; Path={}", 211 | name, value, meta.max_age, meta.path, 212 | ), 213 | None => format!( 214 | "{}=None; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Path={}", 215 | name, meta.path, 216 | ), 217 | }; 218 | 219 | if let Some(domain) = meta.domain { 220 | write!(s, "; Domain={domain}").unwrap(); 221 | } 222 | 223 | if meta.http_only { 224 | write!(s, "; HttpOnly").unwrap(); 225 | } 226 | 227 | if let Some(same_site) = meta.same_site { 228 | write!(s, "; SameSite={same_site:?}").unwrap(); 229 | } 230 | 231 | if meta.secure { 232 | write!(s, "; Secure").unwrap(); 233 | } 234 | 235 | Ok(HeaderValue::try_from(s)?) 
236 | } 237 | 238 | #[derive(Debug, Clone, Copy)] 239 | pub enum SameSite { 240 | Lax, 241 | None, 242 | Strict, 243 | } 244 | 245 | #[derive(Debug, Error)] 246 | pub enum Error { 247 | #[error("unable to serialize cookie data")] 248 | DataSerializationFailed(#[from] postcard::Error), 249 | #[error("expiry window too long")] 250 | ExpiryWindowTooLong, 251 | #[error("non-ASCII cookie name")] 252 | InvalidCookieName(#[from] InvalidHeaderValue), 253 | #[error("key error: {0}")] 254 | Key(#[from] crate::key::Error), 255 | } 256 | 257 | #[cfg(test)] 258 | mod test { 259 | use http::{header, HeaderMap}; 260 | use serde::{Deserialize, Serialize}; 261 | 262 | use super::*; 263 | 264 | /// This test checks that we can extract a cookie from a request that uses multiple cookies in a single header 265 | #[test] 266 | fn test_multiple_cookies_in_single_header() { 267 | let key = crate::key::Key::from_hex_lower( 268 | b"db9881d396644d64818c0bc192d161addb9881d396644d64818c0bc192d161ad", 269 | ) 270 | .unwrap(); 271 | let session = Session { id: 2 }; 272 | 273 | let mut headers = HeaderMap::new(); 274 | let meta = Session::meta(); 275 | let cookie_value = Cookie::encode(Session::NAME, session, &meta, &key).unwrap(); 276 | let header_value = format!("_internal_s=logs=1&id=toast;Session={cookie_value};RefreshToken=tWEnTuXNfmCV_ZNYZQXvMeZ8AN5KUqas7vsqY1wwcWa6TfxYEqekcBVIpagFXn06XsHSN8GZQqGi2w1jd2Atj-aEwNq2wknQjpmxFKIMAnOYFd6gcCoG6Q").parse().unwrap(); 277 | headers.insert(header::COOKIE, header_value); 278 | 279 | assert_eq!(super::extract::(&key, &headers).unwrap().id, 2); 280 | } 281 | 282 | /// This test checks that we can extract a cookie from a request that uses separate headers for each cookie 283 | #[test] 284 | fn test_separate_cookie_headers() { 285 | let key = crate::key::Key::from_hex_lower( 286 | b"db9881d396644d64818c0bc192d161addb9881d396644d64818c0bc192d161ad", 287 | ) 288 | .unwrap(); 289 | let session = Session { id: 2 }; 290 | 291 | let mut headers = 
HeaderMap::new(); 292 | headers.insert( 293 | header::COOKIE, 294 | "_internal_s=logs=1&id=toast;".parse().unwrap(), 295 | ); 296 | 297 | let meta = Session::meta(); 298 | let cookie_value = Cookie::encode(Session::NAME, session, &meta, &key).unwrap(); 299 | headers.append( 300 | header::COOKIE, 301 | format!("Session={cookie_value}").parse().unwrap(), 302 | ); 303 | headers.append(header::COOKIE, "RefreshToken=tWEnTuXNfmCV_ZNYZQXvMeZ8AN5KUqas7vsqY1wwcWa6TfxYEqekcBVIpagFXn06XsHSN8GZQqGi2w1jd2Atj-aEwNq2wknQjpmxFKIMAnOYFd6gcCoG6Q".parse().unwrap()); 304 | 305 | assert_eq!(super::extract::(&key, &headers).unwrap().id, 2); 306 | } 307 | 308 | #[derive(Clone, Copy, Debug, Deserialize, Serialize)] 309 | pub struct Session { 310 | id: i64, 311 | } 312 | 313 | impl super::CookieData for Session { 314 | const NAME: &'static str = "Session"; 315 | } 316 | } 317 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /mendes/src/body.rs: -------------------------------------------------------------------------------- 1 | use std::future::Future; 2 | use std::pin::Pin; 3 | use std::str::FromStr; 4 | use std::sync::Arc; 5 | use std::task::ready; 6 | use std::task::Poll; 7 | use std::{io, mem, str}; 8 | 9 | #[cfg(feature = "brotli")] 10 | use async_compression::tokio::bufread::BrotliEncoder; 11 | #[cfg(feature = "gzip")] 12 | use async_compression::tokio::bufread::GzipEncoder; 13 | #[cfg(feature = "zlib")] 14 | use async_compression::tokio::bufread::ZlibEncoder; 15 | use bytes::{Buf, Bytes, BytesMut}; 16 | #[cfg(any(feature = "brotli", feature = "gzip", feature = "zlib"))] 17 | use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING}; 18 | use http::request::Parts; 19 | #[cfg(any(feature = "brotli", feature = "gzip", feature = "zlib"))] 20 | use http::HeaderMap; 21 | #[cfg(any(feature = "brotli", feature = "gzip", feature = "zlib"))] 22 | use http::{request, HeaderValue, Response}; 23 | use http_body::{Frame, SizeHint}; 24 | use pin_project::pin_project; 25 | 
#[cfg(any(feature = "brotli", feature = "gzip", feature = "zlib"))] 26 | use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; 27 | #[cfg(any(feature = "brotli", feature = "gzip", feature = "zlib"))] 28 | use tokio_util::io::poll_read_buf; 29 | 30 | use crate::application::{Application, FromContext, PathState}; 31 | 32 | #[pin_project] 33 | pub struct Body { 34 | #[pin] 35 | inner: InnerBody, 36 | full_size: u64, 37 | done: bool, 38 | } 39 | 40 | impl Body { 41 | pub fn empty() -> Self { 42 | Self { 43 | inner: InnerBody::Bytes(Bytes::new()), 44 | full_size: 0, 45 | done: true, 46 | } 47 | } 48 | 49 | pub fn lazy(future: impl Future> + Send + 'static) -> Self { 50 | Self { 51 | inner: InnerBody::Lazy { 52 | future: Box::pin(future), 53 | encoding: Encoding::Identity, 54 | }, 55 | full_size: 0, 56 | done: false, 57 | } 58 | } 59 | 60 | pub fn stream( 61 | stream: impl http_body::Body + Send + 'static, 62 | ) -> Self { 63 | Self { 64 | inner: InnerBody::Streaming(Box::pin(stream)), 65 | full_size: 0, 66 | done: false, 67 | } 68 | } 69 | } 70 | 71 | impl<'a, A: Application> FromContext<'a, A> for Body { 72 | fn from_context( 73 | _: &'a Arc, 74 | _: &'a Parts, 75 | _: &mut PathState, 76 | body: &mut Option, 77 | ) -> Result { 78 | match body.take() { 79 | Some(body) => Ok(body), 80 | None => panic!("attempted to retrieve body twice"), 81 | } 82 | } 83 | } 84 | 85 | impl http_body::Body for Body { 86 | type Data = Bytes; 87 | type Error = io::Error; 88 | 89 | #[allow(unused_variables)] // Depends on features 90 | fn poll_frame( 91 | self: Pin<&mut Self>, 92 | cx: &mut std::task::Context<'_>, 93 | ) -> Poll, Self::Error>>> { 94 | let this = self.project(); 95 | if *this.done { 96 | return Poll::Ready(None); 97 | } 98 | 99 | #[allow(unused_mut)] // Depends on features 100 | let mut buf = BytesMut::new(); 101 | let result = match this.inner.project() { 102 | #[cfg(feature = "brotli")] 103 | PinnedBody::Brotli(encoder) => poll_read_buf(encoder, cx, &mut buf), 104 | 
#[cfg(feature = "gzip")] 105 | PinnedBody::Gzip(encoder) => poll_read_buf(encoder, cx, &mut buf), 106 | #[cfg(feature = "zlib")] 107 | PinnedBody::Zlib(encoder) => poll_read_buf(encoder, cx, &mut buf), 108 | PinnedBody::Bytes(bytes) => { 109 | *this.done = true; 110 | let bytes = mem::take(bytes.get_mut()); 111 | return Poll::Ready(match bytes.has_remaining() { 112 | true => Some(Ok(Frame::data(bytes))), 113 | false => None, 114 | }); 115 | } 116 | #[cfg(feature = "hyper")] 117 | PinnedBody::Hyper(mut inner) => { 118 | return Poll::Ready(match ready!(inner.as_mut().poll_frame(cx)) { 119 | Some(Ok(frame)) => Some(Ok(frame)), 120 | Some(Err(error)) => Some(Err(io::Error::other(error))), 121 | None => { 122 | *this.done = true; 123 | None 124 | } 125 | }) 126 | } 127 | PinnedBody::Streaming(inner) => match ready!(inner.as_mut().poll_frame(cx)) { 128 | Some(item) => return Poll::Ready(Some(item)), 129 | None => { 130 | *this.done = true; 131 | return Poll::Ready(None); 132 | } 133 | }, 134 | PinnedBody::Lazy { future, encoding } => { 135 | let bytes = match ready!(future.as_mut().poll(cx)) { 136 | Ok(bytes) => bytes, 137 | Err(error) => return Poll::Ready(Some(Err(error))), 138 | }; 139 | 140 | let len = bytes.len(); 141 | let mut inner = InnerBody::wrap(bytes, *encoding); 142 | *this.full_size = len as u64; 143 | // The duplication here is pretty ugly, but I couldn't come up with anything better. 
144 | match &mut inner { 145 | #[cfg(feature = "brotli")] 146 | InnerBody::Brotli(encoder) => poll_read_buf(Pin::new(encoder), cx, &mut buf), 147 | #[cfg(feature = "gzip")] 148 | InnerBody::Gzip(encoder) => poll_read_buf(Pin::new(encoder), cx, &mut buf), 149 | #[cfg(feature = "zlib")] 150 | InnerBody::Zlib(encoder) => poll_read_buf(Pin::new(encoder), cx, &mut buf), 151 | InnerBody::Bytes(bytes) => { 152 | *this.done = true; 153 | let bytes = mem::take(bytes); 154 | return Poll::Ready(match bytes.has_remaining() { 155 | true => Some(Ok(Frame::data(bytes))), 156 | false => None, 157 | }); 158 | } 159 | #[cfg(feature = "hyper")] 160 | InnerBody::Hyper(_) => unreachable!(), 161 | InnerBody::Lazy { .. } | InnerBody::Streaming(_) => { 162 | unreachable!() 163 | } 164 | } 165 | } 166 | }; 167 | 168 | #[cfg(any(feature = "brotli", feature = "gzip", feature = "zlib"))] 169 | match ready!(result) { 170 | Ok(0) => { 171 | *this.done = true; 172 | Poll::Ready(None) 173 | } 174 | Ok(n) => { 175 | *this.full_size = this.full_size.saturating_sub(n as u64); 176 | Poll::Ready(Some(Ok(Frame::data(buf.freeze())))) 177 | } 178 | Err(error) => Poll::Ready(Some(Err(error))), 179 | } 180 | } 181 | 182 | fn is_end_stream(&self) -> bool { 183 | self.done 184 | } 185 | 186 | fn size_hint(&self) -> http_body::SizeHint { 187 | match (self.done, &self.inner) { 188 | (true, _) => SizeHint::with_exact(0), 189 | (false, InnerBody::Bytes(body)) => SizeHint::with_exact(body.len() as u64), 190 | #[cfg(feature = "hyper")] 191 | (false, InnerBody::Hyper(inner)) => inner.size_hint(), 192 | (false, InnerBody::Lazy { .. } | InnerBody::Streaming(_)) => SizeHint::default(), 193 | // The duplication here is pretty ugly, but I couldn't come up with anything better. 
194 | #[cfg(feature = "brotli")] 195 | (false, InnerBody::Brotli(_)) => { 196 | let mut hint = SizeHint::default(); 197 | hint.set_lower(1); 198 | hint.set_upper(self.full_size + 256); 199 | hint 200 | } 201 | #[cfg(feature = "gzip")] 202 | (false, InnerBody::Gzip(_)) => { 203 | let mut hint = SizeHint::default(); 204 | hint.set_lower(1); 205 | hint.set_upper(self.full_size + 256); 206 | hint 207 | } 208 | #[cfg(feature = "zlib")] 209 | (false, InnerBody::Zlib(_)) => { 210 | let mut hint = SizeHint::default(); 211 | hint.set_lower(1); 212 | hint.set_upper(self.full_size + 256); 213 | hint 214 | } 215 | } 216 | } 217 | } 218 | 219 | #[cfg(feature = "hyper")] 220 | impl From for Body { 221 | fn from(inner: hyper::body::Incoming) -> Self { 222 | Self { 223 | inner: InnerBody::Hyper(inner), 224 | full_size: 0, 225 | done: false, 226 | } 227 | } 228 | } 229 | 230 | impl From> for Body { 231 | fn from(data: Vec) -> Self { 232 | Self::from(Bytes::from(data)) 233 | } 234 | } 235 | 236 | impl From for Body { 237 | fn from(data: String) -> Self { 238 | Self::from(Bytes::from(data)) 239 | } 240 | } 241 | 242 | impl From<&'static str> for Body { 243 | fn from(data: &'static str) -> Self { 244 | Self::from(Bytes::from(data)) 245 | } 246 | } 247 | 248 | impl From for Body { 249 | fn from(data: Bytes) -> Self { 250 | Self { 251 | done: !data.has_remaining(), 252 | full_size: data.len() as u64, 253 | inner: InnerBody::Bytes(data), 254 | } 255 | } 256 | } 257 | 258 | impl Default for Body { 259 | fn default() -> Self { 260 | Self::empty() 261 | } 262 | } 263 | 264 | #[cfg(any(feature = "brotli", feature = "gzip", feature = "zlib"))] 265 | impl EncodeResponse for Response { 266 | fn encoded(mut self, req: &request::Parts) -> Response { 267 | let buf = match self.body_mut() { 268 | Body { done: true, .. } => return self, 269 | Body { 270 | inner: InnerBody::Bytes(buf), 271 | .. 
272 | } => mem::take(buf), 273 | Body { 274 | inner: 275 | InnerBody::Lazy { 276 | encoding: enc @ Encoding::Identity, 277 | .. 278 | }, 279 | .. 280 | } => { 281 | let new = Encoding::from_accept(&req.headers).unwrap_or(Encoding::Identity); 282 | *enc = new; 283 | return self; 284 | } 285 | Body { .. } => return self, 286 | }; 287 | 288 | let len = buf.len(); 289 | let encoding = Encoding::from_accept(&req.headers).unwrap_or(Encoding::Identity); 290 | let inner = InnerBody::wrap(buf, encoding); 291 | if let Some(encoding) = encoding.as_str() { 292 | self.headers_mut() 293 | .insert(CONTENT_ENCODING, HeaderValue::from_static(encoding)); 294 | } 295 | 296 | let body = self.body_mut(); 297 | body.full_size = len as u64; 298 | body.inner = inner; 299 | self 300 | } 301 | } 302 | 303 | #[cfg(any(feature = "brotli", feature = "gzip", feature = "zlib"))] 304 | pub trait EncodeResponse { 305 | fn encoded(self, req: &request::Parts) -> Self; 306 | } 307 | 308 | #[pin_project(project = PinnedBody)] 309 | enum InnerBody { 310 | #[cfg(feature = "brotli")] 311 | Brotli(#[pin] Box>), 312 | #[cfg(feature = "gzip")] 313 | Gzip(#[pin] GzipEncoder), 314 | #[cfg(feature = "zlib")] 315 | Zlib(#[pin] ZlibEncoder), 316 | Bytes(#[pin] Bytes), 317 | #[cfg(feature = "hyper")] 318 | Hyper(#[pin] hyper::body::Incoming), 319 | Lazy { 320 | future: Pin> + Send>>, 321 | encoding: Encoding, 322 | }, 323 | Streaming(Pin + Send>>), 324 | } 325 | 326 | impl InnerBody { 327 | fn wrap(buf: Bytes, encoding: Encoding) -> Self { 328 | match encoding { 329 | #[cfg(feature = "brotli")] 330 | Encoding::Brotli => Self::Brotli(Box::new(BrotliEncoder::new(BufReader { buf }))), 331 | #[cfg(feature = "gzip")] 332 | Encoding::Gzip => Self::Gzip(GzipEncoder::new(BufReader { buf })), 333 | #[cfg(feature = "zlib")] 334 | Encoding::Zlib => Self::Zlib(ZlibEncoder::new(BufReader { buf })), 335 | Encoding::Identity => Self::Bytes(buf), 336 | } 337 | } 338 | } 339 | 340 | #[cfg(any(feature = "brotli", feature = "gzip", 
feature = "zlib"))] 341 | struct BufReader { 342 | pub(crate) buf: Bytes, 343 | } 344 | 345 | #[cfg(any(feature = "brotli", feature = "gzip", feature = "zlib"))] 346 | impl AsyncBufRead for BufReader { 347 | fn poll_fill_buf( 348 | self: Pin<&mut Self>, 349 | _: &mut std::task::Context<'_>, 350 | ) -> Poll> { 351 | Poll::Ready(Ok(self.get_mut().buf.chunk())) 352 | } 353 | 354 | fn consume(self: Pin<&mut Self>, amt: usize) { 355 | self.get_mut().buf.advance(amt); 356 | } 357 | } 358 | 359 | #[cfg(any(feature = "brotli", feature = "gzip", feature = "zlib"))] 360 | impl AsyncRead for BufReader { 361 | fn poll_read( 362 | self: Pin<&mut Self>, 363 | _: &mut std::task::Context<'_>, 364 | buf: &mut ReadBuf<'_>, 365 | ) -> Poll> { 366 | let len = Ord::min(self.buf.remaining(), buf.remaining()); 367 | self.get_mut() 368 | .buf 369 | .copy_to_slice(buf.initialize_unfilled_to(len)); 370 | Poll::Ready(Ok(())) 371 | } 372 | } 373 | 374 | #[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Ord)] 375 | enum Encoding { 376 | #[cfg(feature = "brotli")] 377 | Brotli, 378 | #[cfg(feature = "gzip")] 379 | Gzip, 380 | #[cfg(feature = "zlib")] 381 | Zlib, 382 | Identity, 383 | } 384 | 385 | impl Encoding { 386 | #[cfg(any(feature = "brotli", feature = "gzip", feature = "zlib"))] 387 | fn from_accept(headers: &HeaderMap) -> Option { 388 | let accept = match headers.get(ACCEPT_ENCODING).map(|hv| hv.to_str()) { 389 | Some(Ok(accept)) => accept, 390 | _ => return None, 391 | }; 392 | 393 | let mut encodings = accept 394 | .split(',') 395 | .filter_map(|s| { 396 | let mut parts = s.splitn(2, ';'); 397 | let alg = match Encoding::from_str(parts.next()?.trim()) { 398 | Ok(encoding) => encoding, 399 | Err(()) => return None, 400 | }; 401 | 402 | let qual = parts 403 | .next() 404 | .and_then(|s| { 405 | let mut parts = s.splitn(2, '='); 406 | if parts.next()?.trim() != "q" { 407 | return None; 408 | } 409 | 410 | let value = parts.next()?; 411 | f64::from_str(value).ok() 412 | }) 413 | 
.unwrap_or(1.0); 414 | 415 | Some((alg, (qual * 100.0) as u64)) 416 | }) 417 | .collect::>(); 418 | encodings.sort_by_key(|(algo, qual)| (-(*qual as i64), *algo)); 419 | 420 | encodings.into_iter().next().map(|(algo, _)| algo) 421 | } 422 | } 423 | 424 | impl Encoding { 425 | #[cfg(any(feature = "brotli", feature = "gzip", feature = "zlib"))] 426 | pub fn as_str(self) -> Option<&'static str> { 427 | match self { 428 | #[cfg(feature = "brotli")] 429 | Self::Brotli => Some("br"), 430 | #[cfg(feature = "gzip")] 431 | Self::Gzip => Some("gzip"), 432 | Self::Identity => None, 433 | // The `deflate` encoding is actually zlib, but the HTTP standard calls it `deflate`. 434 | #[cfg(feature = "zlib")] 435 | Self::Zlib => Some("deflate"), 436 | } 437 | } 438 | } 439 | 440 | impl FromStr for Encoding { 441 | type Err = (); 442 | 443 | fn from_str(s: &str) -> Result { 444 | Ok(match s { 445 | #[cfg(feature = "brotli")] 446 | "br" => Encoding::Brotli, 447 | #[cfg(feature = "gzip")] 448 | "gzip" => Encoding::Gzip, 449 | "identity" => Encoding::Identity, 450 | #[cfg(feature = "zlib")] 451 | "deflate" => Encoding::Zlib, 452 | _ => return Err(()), 453 | }) 454 | } 455 | } 456 | -------------------------------------------------------------------------------- /mendes/src/forms.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Cow; 2 | use std::{fmt, str}; 3 | 4 | pub use mendes_macros::{form, ToField}; 5 | use thiserror::Error; 6 | 7 | #[cfg(feature = "uploads")] 8 | #[cfg_attr(docsrs, doc(cfg(feature = "uploads")))] 9 | pub use crate::multipart::{from_form_data, File}; 10 | 11 | /// A data type that knows how to generate an HTML form for itself 12 | /// 13 | /// Implementations are usually generated using the `form` procedural macro attribute. 
14 | pub trait ToForm { 15 | fn to_form() -> Form; 16 | } 17 | 18 | /// An HTML form 19 | pub struct Form { 20 | /// The form action (a relative URL to send the form to) 21 | pub action: Option>, 22 | /// The form data encoding type 23 | pub enctype: Option>, 24 | /// The method to use on form submission 25 | pub method: Option>, 26 | /// List of classes to set on the form element 27 | pub classes: Vec>, 28 | /// The field sets contained in this form 29 | pub sets: Vec
, 30 | } 31 | 32 | impl Form { 33 | // This should only be used by procedural macros. 34 | #[doc(hidden)] 35 | pub fn prepare(mut self) -> Self { 36 | let multipart = self 37 | .sets 38 | .iter() 39 | .flat_map(|s| &s.items) 40 | .any(|i| i.multipart()); 41 | if multipart { 42 | self.enctype = Some("multipart/form-data".into()); 43 | } 44 | self 45 | } 46 | 47 | pub fn set(mut self, name: &str, value: T) -> Result { 48 | self.sets 49 | .iter_mut() 50 | .flat_map(|s| &mut s.items) 51 | .try_fold((), |_, item| item.set(name, &value)) 52 | .map(|_| self) 53 | } 54 | } 55 | 56 | impl fmt::Display for Form { 57 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 58 | write!(fmt, " write!(fmt, "{s}")?, 73 | _ => write!(fmt, " {s}")?, 74 | } 75 | } 76 | write!(fmt, "\"")?; 77 | } 78 | write!(fmt, ">")?; 79 | for set in &self.sets { 80 | write!(fmt, "{set}")?; 81 | } 82 | write!(fmt, "") 83 | } 84 | } 85 | 86 | pub struct FieldSet { 87 | pub legend: Option<&'static str>, 88 | pub items: Vec, 89 | } 90 | 91 | impl fmt::Display for FieldSet { 92 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 93 | write!(fmt, "
")?; 94 | if let Some(s) = self.legend { 95 | write!(fmt, "{s}")?; 96 | } 97 | for item in &self.items { 98 | write!(fmt, "{item}")?; 99 | } 100 | write!(fmt, "
") 101 | } 102 | } 103 | 104 | pub struct Item { 105 | pub label: Option>, 106 | pub contents: ItemContents, 107 | } 108 | 109 | impl Item { 110 | fn set(&mut self, name: &str, value: &T) -> Result<(), Error> { 111 | match &mut self.contents { 112 | ItemContents::Single(f) => { 113 | if f.name() != Some(name) { 114 | return Ok(()); 115 | } 116 | 117 | match f { 118 | Field::Checkbox(f) => { 119 | let s = value.to_string(); 120 | if s == "true" || s == "1" { 121 | f.checked = true; 122 | Ok(()) 123 | } else if s == "false" || s == "0" { 124 | f.checked = false; 125 | Ok(()) 126 | } else { 127 | Err(Error::SetInvalidBooleanValue) 128 | } 129 | } 130 | Field::Date(f) => { 131 | f.value = Some(value.to_string().into()); 132 | Ok(()) 133 | } 134 | Field::Email(f) => { 135 | f.value = Some(value.to_string().into()); 136 | Ok(()) 137 | } 138 | Field::Hidden(f) => { 139 | f.value = Some(value.to_string().into()); 140 | Ok(()) 141 | } 142 | Field::Number(f) => { 143 | f.value = Some(value.to_string().into()); 144 | Ok(()) 145 | } 146 | Field::Password(f) => { 147 | f.value = Some(value.to_string().into()); 148 | Ok(()) 149 | } 150 | Field::Select(f) => { 151 | let val = value.to_string(); 152 | for option in &mut f.options { 153 | if option.value == val { 154 | option.selected = true; 155 | return Ok(()); 156 | } 157 | } 158 | Err(Error::SetOptionNotFound) 159 | } 160 | Field::Text(f) => { 161 | f.value = Some(value.to_string().into()); 162 | Ok(()) 163 | } 164 | Field::File(_) | Field::Submit(_) => Err(Error::SetUnsupportedFieldType), 165 | } 166 | } 167 | ItemContents::Multi(items) => { 168 | for item in items { 169 | item.set(name, value)?; 170 | } 171 | Ok(()) 172 | } 173 | } 174 | } 175 | 176 | fn multipart(&self) -> bool { 177 | match &self.contents { 178 | ItemContents::Single(f) => matches!(f, Field::File(_)), 179 | ItemContents::Multi(items) => items.iter().any(|i| i.multipart()), 180 | } 181 | } 182 | } 183 | 184 | impl fmt::Display for Item { 185 | fn fmt(&self, 
fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 186 | match (&self.contents, &self.label) { 187 | (ItemContents::Single(Field::Submit(_)), None) => write!(fmt, "{}", self.contents), 188 | (ItemContents::Single(f), Some(l)) => write!( 189 | fmt, 190 | r#"{}"#, 191 | f.name().unwrap(), 192 | l, 193 | self.contents 194 | ), 195 | (_, Some(l)) => write!(fmt, r#"{}"#, l, self.contents), 196 | (_, None) => write!(fmt, "{}", self.contents), 197 | } 198 | } 199 | } 200 | 201 | pub enum ItemContents { 202 | Single(Field), 203 | Multi(Vec), 204 | } 205 | 206 | impl fmt::Display for ItemContents { 207 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 208 | match self { 209 | ItemContents::Single(f) => write!(fmt, "{f}"), 210 | ItemContents::Multi(items) => { 211 | write!(fmt, r#"
"#)?; 212 | for item in items { 213 | write!(fmt, "{item}")?; 214 | } 215 | write!(fmt, "
") 216 | } 217 | } 218 | } 219 | } 220 | 221 | pub enum Field { 222 | Checkbox(Checkbox), 223 | Date(Date), 224 | Email(Email), 225 | File(FileInput), 226 | Hidden(Hidden), 227 | Number(Number), 228 | Password(Password), 229 | Select(Select), 230 | Submit(Submit), 231 | Text(Text), 232 | } 233 | 234 | impl Field { 235 | pub fn name(&self) -> Option<&str> { 236 | use Field::*; 237 | match self { 238 | Checkbox(f) => Some(&f.name), 239 | Date(f) => Some(&f.name), 240 | Email(f) => Some(&f.name), 241 | File(f) => Some(&f.name), 242 | Hidden(f) => Some(&f.name), 243 | Number(f) => Some(&f.name), 244 | Password(f) => Some(&f.name), 245 | Select(f) => Some(&f.name), 246 | Text(f) => Some(&f.name), 247 | Submit(_) => None, 248 | } 249 | } 250 | } 251 | 252 | impl fmt::Display for Field { 253 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 254 | use Field::*; 255 | match self { 256 | Checkbox(f) => write!(fmt, "{f}"), 257 | Date(f) => write!(fmt, "{f}"), 258 | Email(f) => write!(fmt, "{f}"), 259 | File(f) => write!(fmt, "{f}"), 260 | Hidden(f) => write!(fmt, "{f}"), 261 | Number(f) => write!(fmt, "{f}"), 262 | Password(f) => write!(fmt, "{f}"), 263 | Select(f) => write!(fmt, "{f}"), 264 | Submit(f) => write!(fmt, "{f}"), 265 | Text(f) => write!(fmt, "{f}"), 266 | } 267 | } 268 | } 269 | 270 | pub struct Checkbox { 271 | pub name: Cow<'static, str>, 272 | pub checked: bool, 273 | } 274 | 275 | impl fmt::Display for Checkbox { 276 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 277 | write!( 278 | fmt, 279 | r#"") 286 | } 287 | } 288 | 289 | pub struct Date { 290 | pub name: Cow<'static, str>, 291 | pub value: Option>, 292 | } 293 | 294 | impl fmt::Display for Date { 295 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 296 | write!(fmt, r#"") 301 | } 302 | } 303 | 304 | pub struct Email { 305 | pub name: Cow<'static, str>, 306 | pub value: Option>, 307 | } 308 | 309 | impl fmt::Display for Email { 310 | fn fmt(&self, fmt: &mut 
fmt::Formatter<'_>) -> fmt::Result { 311 | write!(fmt, r#"") 316 | } 317 | } 318 | 319 | pub struct FileInput { 320 | pub name: Cow<'static, str>, 321 | } 322 | 323 | impl fmt::Display for FileInput { 324 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 325 | write!(fmt, r#""#, self.name) 326 | } 327 | } 328 | 329 | pub struct Hidden { 330 | pub name: Cow<'static, str>, 331 | pub value: Option>, 332 | } 333 | 334 | impl Hidden { 335 | fn from_params(name: Cow<'static, str>, _: &[(&str, &str)]) -> Self { 336 | Self { name, value: None } 337 | } 338 | } 339 | 340 | impl fmt::Display for Hidden { 341 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 342 | write!(fmt, r#"") 347 | } 348 | } 349 | 350 | pub struct Number { 351 | pub name: Cow<'static, str>, 352 | pub value: Option>, 353 | } 354 | 355 | impl fmt::Display for Number { 356 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 357 | write!(fmt, r#"") 362 | } 363 | } 364 | 365 | pub struct Password { 366 | pub name: Cow<'static, str>, 367 | pub value: Option>, 368 | } 369 | 370 | impl fmt::Display for Password { 371 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 372 | write!(fmt, r#"") 377 | } 378 | } 379 | 380 | pub struct Select { 381 | pub name: Cow<'static, str>, 382 | pub options: Vec, 383 | } 384 | 385 | impl fmt::Display for Select { 386 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 387 | write!(fmt, r#"") 392 | } 393 | } 394 | 395 | pub struct SelectOption { 396 | pub label: Cow<'static, str>, 397 | pub value: Cow<'static, str>, 398 | pub disabled: bool, 399 | pub selected: bool, 400 | } 401 | 402 | impl fmt::Display for SelectOption { 403 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 404 | write!(fmt, r#"", self.label) 412 | } 413 | } 414 | 415 | pub struct Submit { 416 | pub value: Option>, 417 | } 418 | 419 | impl fmt::Display for Submit { 420 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 421 | 
write!(fmt, r#"") 426 | } 427 | } 428 | 429 | pub struct Text { 430 | pub name: Cow<'static, str>, 431 | pub value: Option>, 432 | } 433 | 434 | impl fmt::Display for Text { 435 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 436 | write!(fmt, r#"") 441 | } 442 | } 443 | 444 | pub trait ToField { 445 | fn to_field(name: Cow<'static, str>, params: &[(&str, &str)]) -> Field; 446 | } 447 | 448 | impl ToField for bool { 449 | fn to_field(name: Cow<'static, str>, _: &[(&str, &str)]) -> Field { 450 | Field::Checkbox(Checkbox { 451 | name, 452 | checked: false, 453 | }) 454 | } 455 | } 456 | 457 | impl ToField for String { 458 | fn to_field(name: Cow<'static, str>, params: &[(&str, &str)]) -> Field { 459 | for (key, value) in params { 460 | if *key == "type" { 461 | if *value == "hidden" { 462 | return Field::Hidden(Hidden::from_params(name, params)); 463 | } else if *value == "email" { 464 | return Field::Email(Email { name, value: None }); 465 | } else if *value == "password" { 466 | return Field::Password(Password { name, value: None }); 467 | } 468 | } 469 | } 470 | Field::Text(Text { name, value: None }) 471 | } 472 | } 473 | 474 | impl ToField for Cow<'_, str> { 475 | fn to_field(name: Cow<'static, str>, params: &[(&str, &str)]) -> Field { 476 | for (key, value) in params { 477 | if *key == "type" { 478 | if *value == "hidden" { 479 | return Field::Hidden(Hidden::from_params(name, params)); 480 | } else if *value == "email" { 481 | return Field::Email(Email { name, value: None }); 482 | } else if *value == "password" { 483 | return Field::Password(Password { name, value: None }); 484 | } 485 | } 486 | } 487 | Field::Text(Text { name, value: None }) 488 | } 489 | } 490 | 491 | impl ToField for u8 { 492 | fn to_field(name: Cow<'static, str>, params: &[(&str, &str)]) -> Field { 493 | for (key, value) in params { 494 | if *key == "type" && *value == "hidden" { 495 | return Field::Hidden(Hidden::from_params(name, params)); 496 | } 497 | } 498 | 
Field::Number(Number { name, value: None }) 499 | } 500 | } 501 | 502 | impl ToField for u16 { 503 | fn to_field(name: Cow<'static, str>, params: &[(&str, &str)]) -> Field { 504 | for (key, value) in params { 505 | if *key == "type" && *value == "hidden" { 506 | return Field::Hidden(Hidden::from_params(name, params)); 507 | } 508 | } 509 | Field::Number(Number { name, value: None }) 510 | } 511 | } 512 | 513 | impl ToField for u32 { 514 | fn to_field(name: Cow<'static, str>, params: &[(&str, &str)]) -> Field { 515 | for (key, value) in params { 516 | if *key == "type" && *value == "hidden" { 517 | return Field::Hidden(Hidden::from_params(name, params)); 518 | } 519 | } 520 | Field::Number(Number { name, value: None }) 521 | } 522 | } 523 | 524 | impl ToField for u64 { 525 | fn to_field(name: Cow<'static, str>, params: &[(&str, &str)]) -> Field { 526 | for (key, value) in params { 527 | if *key == "type" && *value == "hidden" { 528 | return Field::Hidden(Hidden::from_params(name, params)); 529 | } 530 | } 531 | Field::Number(Number { name, value: None }) 532 | } 533 | } 534 | 535 | impl ToField for i32 { 536 | fn to_field(name: Cow<'static, str>, params: &[(&str, &str)]) -> Field { 537 | for (key, value) in params { 538 | if *key == "type" && *value == "hidden" { 539 | return Field::Hidden(Hidden::from_params(name, params)); 540 | } 541 | } 542 | Field::Number(Number { name, value: None }) 543 | } 544 | } 545 | 546 | impl ToField for i64 { 547 | fn to_field(name: Cow<'static, str>, params: &[(&str, &str)]) -> Field { 548 | for (key, value) in params { 549 | if *key == "type" && *value == "hidden" { 550 | return Field::Hidden(Hidden::from_params(name, params)); 551 | } 552 | } 553 | Field::Number(Number { name, value: None }) 554 | } 555 | } 556 | 557 | impl ToField for f32 { 558 | fn to_field(name: Cow<'static, str>, params: &[(&str, &str)]) -> Field { 559 | for (key, value) in params { 560 | if *key == "type" && *value == "hidden" { 561 | return 
Field::Hidden(Hidden::from_params(name, params)); 562 | } 563 | } 564 | Field::Number(Number { name, value: None }) 565 | } 566 | } 567 | 568 | #[cfg(feature = "chrono")] 569 | #[cfg_attr(docsrs, doc(cfg(feature = "chrono")))] 570 | impl ToField for chrono::NaiveDate { 571 | fn to_field(name: Cow<'static, str>, _: &[(&str, &str)]) -> Field { 572 | Field::Date(Date { name, value: None }) 573 | } 574 | } 575 | 576 | #[derive(Debug, Error)] 577 | pub enum Error { 578 | #[error("invalid value for boolean field")] 579 | SetInvalidBooleanValue, 580 | #[error("no option with given value found in select")] 581 | SetOptionNotFound, 582 | #[error("unable to set value for unknown field")] 583 | SetUnknownField, 584 | #[error("setting value not supported for this field type")] 585 | SetUnsupportedFieldType, 586 | } 587 | -------------------------------------------------------------------------------- /mendes/src/multipart.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{self, Display}; 2 | use std::str::{self, FromStr}; 3 | 4 | use http::HeaderMap; 5 | use memchr::memmem; 6 | use serde::de::{ 7 | DeserializeSeed, EnumAccess, Error as ErrorTrait, MapAccess, VariantAccess, Visitor, 8 | }; 9 | use serde::Deserialize; 10 | 11 | pub fn from_form_data<'a, T: Deserialize<'a>>( 12 | headers: &HeaderMap, 13 | input: &'a [u8], 14 | ) -> std::result::Result { 15 | let ctype = headers 16 | .get("content-type") 17 | .ok_or_else(|| Error::custom("content-type header not found"))? 18 | .as_bytes(); 19 | let split = 20 | memmem::find(ctype, b"; boundary=").ok_or_else(|| Error::custom("boundary not found"))?; 21 | 22 | let mut boundary = Vec::with_capacity(2 + ctype.len() - split - 11); 23 | boundary.extend(b"--"); 24 | boundary.extend(&ctype[split + 11..]); 25 | 26 | let mut deserializer = Deserializer { 27 | input, 28 | boundary, 29 | state: None, 30 | }; 31 | T::deserialize(&mut deserializer) 32 | } 33 | 34 | macro_rules! 
parse_value_type { 35 | ($($ty:ident => ($visit_method:ident, $deserializer_method:ident),)*) => { 36 | $( 37 | fn $deserializer_method(self, visitor: V) -> Result 38 | where V: Visitor<'de> 39 | { 40 | if let Some((State::Data, Part::Text { data, .. })) = self.state { 41 | let s = str::from_utf8(data) 42 | .map_err(|_| Error::custom("invalid input while UTF-8 decoding for $ty"))?; 43 | visitor.$visit_method( 44 | $ty::from_str(s).map_err(|_| Error::custom("unable to convert str to $ty"))?, 45 | ) 46 | } else { 47 | unreachable!() 48 | } 49 | } 50 | )* 51 | } 52 | } 53 | 54 | pub struct Deserializer<'de> { 55 | input: &'de [u8], 56 | boundary: Vec, 57 | state: Option<(State, Part<'de>)>, 58 | } 59 | 60 | impl<'de> serde::de::Deserializer<'de> for &mut Deserializer<'de> { 61 | type Error = Error; 62 | 63 | fn deserialize_any(self, _: V) -> Result 64 | where 65 | V: Visitor<'de>, 66 | { 67 | unimplemented!() 68 | } 69 | 70 | fn deserialize_identifier(self, visitor: V) -> Result 71 | where 72 | V: Visitor<'de>, 73 | { 74 | match &self.state { 75 | Some((State::Name, part)) => { 76 | let name = match part { 77 | Part::Blob { name, .. } => name, 78 | Part::Text { name, .. } => name, 79 | }; 80 | visitor.visit_borrowed_str(name) 81 | } 82 | Some((State::Filename, part)) => match part { 83 | Part::Blob { .. } => visitor.visit_borrowed_str("filename"), 84 | Part::Text { .. } => unreachable!(), 85 | }, 86 | Some((State::Type, _)) => visitor.visit_borrowed_str("type"), 87 | Some((State::Data, part)) => match part { 88 | Part::Blob { .. } => visitor.visit_borrowed_str("data"), 89 | Part::Text { .. } => self.deserialize_str(visitor), 90 | }, 91 | _ => unreachable!(), 92 | } 93 | } 94 | 95 | fn deserialize_ignored_any(self, _: V) -> Result 96 | where 97 | V: Visitor<'de>, 98 | { 99 | unimplemented!() 100 | } 101 | 102 | parse_value_type! 
{ 103 | bool => (visit_bool, deserialize_bool), 104 | u8 => (visit_u8, deserialize_u8), 105 | u16 => (visit_u16, deserialize_u16), 106 | u32 => (visit_u32, deserialize_u32), 107 | u64 => (visit_u64, deserialize_u64), 108 | i8 => (visit_i8, deserialize_i8), 109 | i16 => (visit_i16, deserialize_i16), 110 | i32 => (visit_i32, deserialize_i32), 111 | i64 => (visit_i64, deserialize_i64), 112 | f32 => (visit_f32, deserialize_f32), 113 | f64 => (visit_f64, deserialize_f64), 114 | } 115 | 116 | fn deserialize_char(self, visitor: V) -> Result 117 | where 118 | V: Visitor<'de>, 119 | { 120 | if let Some((State::Data, Part::Text { data, .. })) = self.state { 121 | let s = str::from_utf8(data) 122 | .map_err(|_| Error::custom("invalid input while UTF-8 decoding for i32"))?; 123 | visitor.visit_char( 124 | char::from_str(s).map_err(|_| Error::custom("unable to convert str to $ty"))?, 125 | ) 126 | } else { 127 | unreachable!() 128 | } 129 | } 130 | 131 | fn deserialize_str(self, visitor: V) -> Result 132 | where 133 | V: Visitor<'de>, 134 | { 135 | match self.state.as_ref() { 136 | Some((State::Name, _)) => unreachable!(), 137 | Some((State::Filename, Part::Blob { filename, .. })) => { 138 | visitor.visit_borrowed_str(filename.as_ref().unwrap()) 139 | } 140 | Some((State::Type, Part::Blob { ctype, .. })) => { 141 | visitor.visit_borrowed_str(ctype.as_ref().unwrap()) 142 | } 143 | Some((State::Data, part)) => { 144 | let data = match part { 145 | Part::Blob { data, .. } => data, 146 | Part::Text { data, .. 
} => data, 147 | }; 148 | let data = str::from_utf8(data) 149 | .map_err(|_| Error::custom("error while decoding str from UTF-8"))?; 150 | visitor.visit_borrowed_str(data) 151 | } 152 | _ => unreachable!(), 153 | } 154 | } 155 | 156 | fn deserialize_string(self, visitor: V) -> Result 157 | where 158 | V: Visitor<'de>, 159 | { 160 | self.deserialize_str(visitor) 161 | } 162 | 163 | fn deserialize_bytes(self, visitor: V) -> Result 164 | where 165 | V: Visitor<'de>, 166 | { 167 | let data = match self.state.as_ref() { 168 | Some((_, Part::Blob { data, .. })) => data, 169 | Some((_, Part::Text { data, .. })) => data, 170 | None => unreachable!(), 171 | }; 172 | visitor.visit_borrowed_bytes(data) 173 | } 174 | 175 | fn deserialize_byte_buf(self, visitor: V) -> Result 176 | where 177 | V: Visitor<'de>, 178 | { 179 | let data = match self.state.as_ref() { 180 | Some((_, Part::Blob { data, .. })) => data, 181 | Some((_, Part::Text { data, .. })) => data, 182 | None => unreachable!(), 183 | }; 184 | visitor.visit_byte_buf(data.to_vec()) 185 | } 186 | 187 | fn deserialize_option(self, visitor: V) -> Result 188 | where 189 | V: Visitor<'de>, 190 | { 191 | match self.state.as_ref() { 192 | Some((State::Filename, part)) => { 193 | if let Part::Blob { 194 | filename: Some(_), .. 195 | } = part 196 | { 197 | visitor.visit_some(self) 198 | } else { 199 | visitor.visit_none() 200 | } 201 | } 202 | Some((State::Type, part)) => { 203 | if let Part::Blob { ctype: Some(_), .. 
} = part { 204 | visitor.visit_some(self) 205 | } else { 206 | visitor.visit_none() 207 | } 208 | } 209 | _ => unreachable!(), 210 | } 211 | } 212 | 213 | fn deserialize_unit(self, _: V) -> Result 214 | where 215 | V: Visitor<'de>, 216 | { 217 | unimplemented!() 218 | } 219 | 220 | fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result 221 | where 222 | V: Visitor<'de>, 223 | { 224 | self.deserialize_unit(visitor) 225 | } 226 | 227 | fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result 228 | where 229 | V: Visitor<'de>, 230 | { 231 | visitor.visit_newtype_struct(self) 232 | } 233 | 234 | fn deserialize_seq(self, _: V) -> Result 235 | where 236 | V: Visitor<'de>, 237 | { 238 | unreachable!() 239 | } 240 | 241 | fn deserialize_tuple(self, _len: usize, visitor: V) -> Result 242 | where 243 | V: Visitor<'de>, 244 | { 245 | self.deserialize_seq(visitor) 246 | } 247 | 248 | // Tuple structs look just like sequences in JSON. 249 | fn deserialize_tuple_struct( 250 | self, 251 | _name: &'static str, 252 | _len: usize, 253 | visitor: V, 254 | ) -> Result 255 | where 256 | V: Visitor<'de>, 257 | { 258 | self.deserialize_seq(visitor) 259 | } 260 | 261 | fn deserialize_map(mut self, visitor: V) -> Result 262 | where 263 | V: Visitor<'de>, 264 | { 265 | visitor.visit_map(&mut self) 266 | } 267 | 268 | fn deserialize_struct( 269 | self, 270 | _name: &'static str, 271 | _fields: &'static [&'static str], 272 | visitor: V, 273 | ) -> Result 274 | where 275 | V: Visitor<'de>, 276 | { 277 | self.deserialize_map(visitor) 278 | } 279 | 280 | fn deserialize_enum( 281 | self, 282 | _name: &'static str, 283 | _variants: &'static [&'static str], 284 | visitor: V, 285 | ) -> Result 286 | where 287 | V: Visitor<'de>, 288 | { 289 | visitor.visit_enum(Enum { de: self }) 290 | } 291 | } 292 | 293 | // Note that we have maps at two levels: the top level as well as the fields 294 | // inside a `File` object (`Part::Blob` variant). 
This is especially relevant 295 | // when deciding to return `Ok(None)` from `next_key_seed()`. 296 | impl<'de> MapAccess<'de> for &mut Deserializer<'de> { 297 | type Error = Error; 298 | 299 | fn next_key_seed(&mut self, seed: K) -> Result> 300 | where 301 | K: DeserializeSeed<'de>, 302 | { 303 | let split_len = self.boundary.len(); 304 | if self.state.is_none() && self.input.starts_with(&self.boundary) { 305 | if &self.input[split_len..split_len + 4] == b"--\r\n" { 306 | return Ok(None); 307 | } 308 | 309 | let (len, part) = Part::from_bytes(&self.input[split_len + 2..], &self.boundary)?; 310 | self.state = Some((State::Name, part)); 311 | self.input = &self.input[split_len + 2 + len..]; 312 | let res = seed.deserialize(&mut **self).map(Some); 313 | self.state = match self.state.take() { 314 | Some((_, part @ Part::Blob { .. })) => Some((State::Filename, part)), 315 | Some((_, part @ Part::Text { .. })) => Some((State::Data, part)), 316 | None => unreachable!(), 317 | }; 318 | res 319 | } else if let Some((state, part)) = &self.state { 320 | match state { 321 | State::Name => seed.deserialize(&mut **self).map(Some), 322 | State::Filename => match part { 323 | Part::Blob { .. } => seed.deserialize(&mut **self).map(Some), 324 | Part::Text { .. } => Ok(None), 325 | }, 326 | State::Type => seed.deserialize(&mut **self).map(Some), 327 | State::Data => seed.deserialize(&mut **self).map(Some), 328 | State::End => { 329 | self.state = None; 330 | Ok(None) 331 | } 332 | } 333 | } else { 334 | unreachable!() 335 | } 336 | } 337 | 338 | fn next_value_seed(&mut self, seed: V) -> Result 339 | where 340 | V: DeserializeSeed<'de>, 341 | { 342 | let res = seed.deserialize(&mut **self); 343 | self.state = match self.state.take() { 344 | Some((State::Name, _)) => unreachable!(), 345 | Some((State::Filename, part)) => Some((State::Type, part)), 346 | Some((State::Type, part)) => Some((State::Data, part)), 347 | Some((State::Data, part)) => match part { 348 | Part::Blob { .. 
} => Some((State::End, part)), 349 | Part::Text { .. } => None, 350 | }, 351 | Some((State::End, _)) => unreachable!(), 352 | None => None, 353 | }; 354 | res 355 | } 356 | } 357 | 358 | struct Enum<'a, 'de: 'a> { 359 | de: &'a mut Deserializer<'de>, 360 | } 361 | 362 | impl<'de> EnumAccess<'de> for Enum<'_, 'de> { 363 | type Error = Error; 364 | type Variant = Self; 365 | 366 | fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> 367 | where 368 | V: DeserializeSeed<'de>, 369 | { 370 | Ok((seed.deserialize(&mut *self.de)?, self)) 371 | } 372 | } 373 | 374 | impl<'de> VariantAccess<'de> for Enum<'_, 'de> { 375 | type Error = Error; 376 | 377 | fn unit_variant(self) -> Result<()> { 378 | Ok(()) 379 | } 380 | 381 | fn newtype_variant_seed(self, seed: T) -> Result 382 | where 383 | T: DeserializeSeed<'de>, 384 | { 385 | seed.deserialize(self.de) 386 | } 387 | 388 | fn tuple_variant(self, _len: usize, _visitor: V) -> Result 389 | where 390 | V: Visitor<'de>, 391 | { 392 | unimplemented!() 393 | } 394 | 395 | fn struct_variant(self, _fields: &'static [&'static str], _visitor: V) -> Result 396 | where 397 | V: Visitor<'de>, 398 | { 399 | unimplemented!() 400 | } 401 | } 402 | 403 | #[derive(Debug)] 404 | enum Part<'a> { 405 | Blob { 406 | name: &'a str, 407 | filename: Option<&'a str>, 408 | ctype: Option<&'a str>, 409 | data: &'a [u8], 410 | }, 411 | Text { 412 | name: &'a str, 413 | data: &'a [u8], 414 | }, 415 | } 416 | 417 | #[derive(Debug)] 418 | enum State { 419 | Name, 420 | Filename, 421 | Type, 422 | Data, 423 | End, 424 | } 425 | 426 | impl<'a> Part<'a> { 427 | fn from_bytes(bytes: &'a [u8], boundary: &[u8]) -> Result<(usize, Self)> { 428 | let mut header_buf = [httparse::EMPTY_HEADER; 4]; 429 | let status = httparse::parse_headers(bytes, &mut header_buf) 430 | .map_err(|_| Error::custom("unable to parse part headers"))?; 431 | let (header_len, headers) = if let httparse::Status::Complete((len, headers)) = status { 432 | (len, headers) 433 | } 
else { 434 | unreachable!(); 435 | }; 436 | 437 | let (mut name, mut filename, mut ctype) = (None, None, None); 438 | for header in headers { 439 | let value = str::from_utf8(header.value) 440 | .map_err(|_| Error::custom("error while decoding UTF-8 from header value"))?; 441 | let header = header.name.to_string().to_ascii_lowercase(); 442 | if header == "content-disposition" { 443 | for param in value.split(';') { 444 | if param.trim() == "form-data" { 445 | continue; 446 | } 447 | 448 | let sep = param 449 | .find('=') 450 | .ok_or_else(|| Error::custom("parameter value not found"))?; 451 | let pname = ¶m[..sep].trim(); 452 | let value = ¶m[sep + 2..param.len() - 1]; 453 | if *pname == "name" { 454 | name = Some(value); 455 | } else if *pname == "filename" { 456 | filename = Some(value); 457 | } 458 | } 459 | } else if header == "content-type" { 460 | ctype = Some(value); 461 | } 462 | } 463 | 464 | let (len, data) = if let Some(pos) = memmem::find(bytes, boundary) { 465 | (pos, &bytes[header_len..pos - 2]) 466 | } else { 467 | (bytes.len(), &bytes[header_len..]) 468 | }; 469 | 470 | let name = name.ok_or_else(|| Error::custom("no name found"))?; 471 | let part = match &filename { 472 | Some(_) => Part::Blob { 473 | name, 474 | filename, 475 | ctype, 476 | data, 477 | }, 478 | None => Part::Text { name, data }, 479 | }; 480 | Ok((len, part)) 481 | } 482 | } 483 | 484 | #[derive(Clone, Debug, PartialEq, Eq)] 485 | pub enum Error { 486 | Message(String), 487 | } 488 | 489 | impl serde::de::Error for Error { 490 | fn custom(msg: T) -> Self { 491 | Error::Message(msg.to_string()) 492 | } 493 | } 494 | 495 | impl Display for Error { 496 | fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 497 | match self { 498 | Error::Message(msg) => formatter.write_str(msg), 499 | } 500 | } 501 | } 502 | 503 | impl std::error::Error for Error {} 504 | 505 | type Result = std::result::Result; 506 | 507 | #[derive(Deserialize)] 508 | pub struct File<'a> { 509 | 
#[serde(rename = "type")] 510 | pub ctype: Option<&'a str>, 511 | pub filename: Option<&'a str>, 512 | pub data: &'a [u8], 513 | } 514 | 515 | impl super::forms::ToField for File<'_> { 516 | fn to_field(name: std::borrow::Cow<'static, str>, _: &[(&str, &str)]) -> super::forms::Field { 517 | super::forms::Field::File(super::forms::FileInput { name }) 518 | } 519 | } 520 | 521 | #[cfg(feature = "uploads")] 522 | #[cfg(test)] 523 | mod tests { 524 | use super::*; 525 | use http::HeaderMap; 526 | use std::convert::TryInto; 527 | 528 | #[test] 529 | fn upload() { 530 | let ctype = "multipart/form-data; boundary=---------------------------200426345241597222021292378679"; 531 | let body = [ 532 | 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 533 | 45, 45, 45, 45, 45, 45, 45, 50, 48, 48, 52, 50, 54, 51, 52, 53, 50, 52, 49, 53, 57, 55, 534 | 50, 50, 50, 48, 50, 49, 50, 57, 50, 51, 55, 56, 54, 55, 57, 13, 10, 67, 111, 110, 116, 535 | 101, 110, 116, 45, 68, 105, 115, 112, 111, 115, 105, 116, 105, 111, 110, 58, 32, 102, 536 | 111, 114, 109, 45, 100, 97, 116, 97, 59, 32, 110, 97, 109, 101, 61, 34, 102, 105, 108, 537 | 101, 34, 59, 32, 102, 105, 108, 101, 110, 97, 109, 101, 61, 34, 105, 49, 56, 110, 34, 538 | 13, 10, 67, 111, 110, 116, 101, 110, 116, 45, 84, 121, 112, 101, 58, 32, 97, 112, 112, 539 | 108, 105, 99, 97, 116, 105, 111, 110, 47, 111, 99, 116, 101, 116, 45, 115, 116, 114, 540 | 101, 97, 109, 13, 10, 13, 10, 73, 195, 177, 116, 195, 171, 114, 110, 195, 162, 116, 541 | 105, 195, 180, 110, 195, 160, 108, 105, 122, 195, 166, 116, 105, 195, 184, 110, 34, 10, 542 | 13, 10, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 543 | 45, 45, 45, 45, 45, 45, 45, 45, 45, 50, 48, 48, 52, 50, 54, 51, 52, 53, 50, 52, 49, 53, 544 | 57, 55, 50, 50, 50, 48, 50, 49, 50, 57, 50, 51, 55, 56, 54, 55, 57, 13, 10, 67, 111, 545 | 110, 116, 101, 110, 116, 45, 68, 105, 115, 112, 111, 115, 105, 116, 105, 111, 110, 58, 546 | 
32, 102, 111, 114, 109, 45, 100, 97, 116, 97, 59, 32, 110, 97, 109, 101, 61, 34, 97, 547 | 115, 115, 101, 116, 34, 13, 10, 13, 10, 50, 13, 10, 45, 45, 45, 45, 45, 45, 45, 45, 45, 548 | 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 50, 48, 549 | 48, 52, 50, 54, 51, 52, 53, 50, 52, 49, 53, 57, 55, 50, 50, 50, 48, 50, 49, 50, 57, 50, 550 | 51, 55, 56, 54, 55, 57, 45, 45, 13, 10, 551 | ]; 552 | 553 | let mut headers = HeaderMap::new(); 554 | headers.insert("content-type", ctype.try_into().unwrap()); 555 | let form = from_form_data::
(&headers, &body).unwrap(); 556 | assert_eq!(form.file.filename, Some("i18n")); 557 | assert_eq!(form.file.ctype, Some("application/octet-stream")); 558 | assert_eq!( 559 | form.file.data, 560 | b"I\xc3\xb1t\xc3\xabrn\xc3\xa2ti\xc3\xb4n\xc3\xa0liz\xc3\xa6ti\xc3\xb8n\"\n" 561 | ); 562 | assert_eq!(form.asset, 2); 563 | } 564 | 565 | #[derive(Deserialize)] 566 | struct Form<'a> { 567 | #[serde(borrow)] 568 | file: File<'a>, 569 | asset: i32, 570 | } 571 | 572 | #[test] 573 | fn enum_field() { 574 | let ctype = "multipart/form-data; boundary=---------------------------345106847831590504122057183932"; 575 | let body = "-----------------------------345106847831590504122057183932\r 576 | Content-Disposition: form-data; name=\"foo\"\r 577 | \r 578 | Foo\r 579 | -----------------------------345106847831590504122057183932\r 580 | Content-Disposition: form-data; name=\"val\"\r 581 | \r 582 | 1\r 583 | -----------------------------345106847831590504122057183932--\r\n"; 584 | 585 | let mut headers = HeaderMap::new(); 586 | headers.insert("content-type", ctype.try_into().unwrap()); 587 | let form = from_form_data::(&headers, body.as_bytes()).unwrap(); 588 | assert_eq!(form.foo, FooBar::Foo); 589 | assert_eq!(form.val, 1); 590 | } 591 | 592 | #[derive(Deserialize)] 593 | struct EnumForm { 594 | foo: FooBar, 595 | val: i32, 596 | } 597 | 598 | #[derive(Debug, Deserialize, PartialEq)] 599 | enum FooBar { 600 | Foo, 601 | Bar, 602 | } 603 | } 604 | -------------------------------------------------------------------------------- /mendes/src/application.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Cow; 2 | #[cfg(feature = "body-util")] 3 | use std::error::Error as StdError; 4 | use std::str::FromStr; 5 | use std::sync::Arc; 6 | 7 | use async_trait::async_trait; 8 | #[cfg(feature = "body-util")] 9 | use bytes::Bytes; 10 | use http::header::LOCATION; 11 | use http::request::Parts; 12 | use http::Request; 13 | use 
http::{Response, StatusCode}; 14 | use http_body::Body as HttpBody; 15 | use percent_encoding::percent_decode_str; 16 | use thiserror::Error; 17 | 18 | pub use mendes_macros::{handler, route, scope}; 19 | 20 | /// Main interface for an application or service 21 | /// 22 | /// The `Application` holds state and routes requests to the proper handlers. A handler gets 23 | /// an immutable reference to the `Application` to access application state. Common usage 24 | /// for this would be to hold a persistent storage connection pool. 25 | /// 26 | /// The `Application` also helps process handler responses. It can handle errors and turn them 27 | /// into HTTP responses, using the `error()` method to transform the `Error` associated type 28 | /// into a `Response`. 29 | #[async_trait] 30 | pub trait Application: Send + Sized { 31 | type RequestBody: Send; 32 | type ResponseBody: HttpBody; 33 | type Error: IntoResponse + WithStatus + From + Send; 34 | 35 | async fn handle(cx: Context) -> Response; 36 | 37 | fn from_query<'a, T: serde::Deserialize<'a>>(req: &'a Parts) -> Result { 38 | let query = req.uri.query().ok_or(Error::QueryMissing)?; 39 | let data = 40 | serde_urlencoded::from_bytes::(query.as_bytes()).map_err(Error::QueryDecode)?; 41 | Ok(data) 42 | } 43 | 44 | fn from_body_bytes<'de, T: serde::de::Deserialize<'de>>( 45 | req: &Parts, 46 | bytes: &'de [u8], 47 | ) -> Result { 48 | from_bytes::(req, bytes) 49 | } 50 | 51 | #[cfg(feature = "body-util")] 52 | #[cfg_attr(docsrs, doc(cfg(feature = "body-util")))] 53 | async fn from_body( 54 | req: &Parts, 55 | body: Self::RequestBody, 56 | max_len: usize, 57 | ) -> Result 58 | where 59 | Self::RequestBody: HttpBody + Send, 60 | ::Data: Send, 61 | ::Error: Into>, 62 | { 63 | // Check if the Content-Length header suggests the body is larger than our max len 64 | // to avoid allocation if we drop the request in any case.
65 | let expected_len = match body.size_hint().upper() { 66 | Some(length) => length, 67 | None => body.size_hint().lower(), 68 | }; 69 | 70 | if expected_len > max_len as u64 { 71 | return Err(Error::BodyTooLarge); 72 | } 73 | 74 | from_body::(req, body, max_len).await 75 | } 76 | 77 | #[cfg(feature = "body-util")] 78 | #[cfg_attr(docsrs, doc(cfg(feature = "body-util")))] 79 | async fn body_bytes(body: B, max_len: usize) -> Result 80 | where 81 | B::Data: Send, 82 | B::Error: Into>, 83 | { 84 | // Check if the Content-Length header suggests the body is larger than our max len 85 | // to avoid allocation if we drop the request in any case. 86 | let expected_len = match body.size_hint().upper() { 87 | Some(length) => length, 88 | None => body.size_hint().lower(), 89 | }; 90 | 91 | if expected_len > max_len as u64 { 92 | return Err(Error::BodyTooLarge); 93 | } 94 | 95 | Ok(to_bytes(body, max_len).await?) 96 | } 97 | 98 | fn redirect(status: StatusCode, path: impl AsRef) -> Response 99 | where 100 | Self::ResponseBody: Default, 101 | { 102 | http::Response::builder() 103 | .status(status) 104 | .header(LOCATION, path.as_ref()) 105 | .body(Self::ResponseBody::default()) 106 | .unwrap() 107 | } 108 | } 109 | 110 | pub trait WithStatus {} 111 | 112 | impl WithStatus for T where StatusCode: for<'a> From<&'a T> {} 113 | 114 | pub trait IntoResponse { 115 | fn into_response(self, app: &A, req: &Parts) -> Response; 116 | } 117 | 118 | impl IntoResponse for Response { 119 | fn into_response(self, _: &A, _: &Parts) -> Response { 120 | self 121 | } 122 | } 123 | 124 | impl, E: IntoResponse> IntoResponse for Result { 125 | fn into_response(self, app: &A, req: &Parts) -> Response { 126 | match self { 127 | Ok(rsp) => rsp.into_response(app, req), 128 | Err(e) => e.into_response(app, req), 129 | } 130 | } 131 | } 132 | 133 | impl IntoResponse for Error { 134 | fn into_response(self, app: &A, req: &Parts) -> Response { 135 | A::Error::from(self).into_response(app, req) 136 | } 137 | 
} 138 | 139 | /// Maintains state during the routing of requests to the selected handler 140 | /// 141 | /// The `Context` is created by the `Server` (or similar code) from a `Request` and 142 | /// reference-counted `Application` instance. It is used to yield parts of the request 143 | /// to a handler or routing context through implementations of the `FromContext` trait. 144 | /// To this end, it immediately decouples the request's headers from its body, because 145 | /// the former are kept alive throughout the request while the body may be ignored 146 | /// for HEAD/GET requests or will be asynchronously consumed by the handler if necessary. 147 | /// 148 | /// Once the request reaches a destination handler, it will typically be destructed into 149 | /// its (remaining) constituent parts for further use by the handler's code. (This is usually 150 | /// taken care of by one of the handler family of procedural macros, like `get`.) 151 | pub struct Context 152 | where 153 | A: Application, 154 | { 155 | pub app: Arc, 156 | pub req: http::request::Parts, 157 | #[doc(hidden)] 158 | pub body: Option, 159 | #[doc(hidden)] 160 | pub path: PathState, 161 | } 162 | 163 | impl Context 164 | where 165 | A: Application, 166 | { 167 | // This should only be used by procedural routing macros. 168 | #[doc(hidden)] 169 | pub fn new(app: Arc, req: Request) -> Context { 170 | let path = PathState::new(req.uri().path()); 171 | let (req, body) = req.into_parts(); 172 | Context { 173 | app, 174 | req, 175 | body: Some(body), 176 | path, 177 | } 178 | } 179 | 180 | // This should only be used by procedural routing macros. 181 | #[doc(hidden)] 182 | pub fn path(&mut self) -> Option> { 183 | path_str(&self.req, &mut self.path).ok().flatten() 184 | } 185 | 186 | // This should only be used by procedural routing macros. 187 | #[doc(hidden)] 188 | pub fn rewind(&mut self) { 189 | self.path.rewind(); 190 | } 191 | 192 | // This should only be used by procedural routing macros. 
193 | #[doc(hidden)] 194 | pub fn take_body(&mut self) -> Option { 195 | self.body.take() 196 | } 197 | 198 | // This should only be used by procedural routing macros. 199 | #[doc(hidden)] 200 | pub fn app(&self) -> &Arc { 201 | &self.app 202 | } 203 | 204 | // This should only be used by procedural routing macros. 205 | #[doc(hidden)] 206 | pub fn method(&self) -> &http::Method { 207 | &self.req.method 208 | } 209 | 210 | // This should only be used by procedural routing macros. 211 | #[doc(hidden)] 212 | pub fn uri(&self) -> &http::uri::Uri { 213 | &self.req.uri 214 | } 215 | 216 | // This should only be used by procedural routing macros. 217 | #[doc(hidden)] 218 | pub fn headers(&self) -> &http::HeaderMap { 219 | &self.req.headers 220 | } 221 | } 222 | 223 | impl AsMut> for Context { 224 | fn as_mut(&mut self) -> &mut Context { 225 | self 226 | } 227 | } 228 | 229 | pub trait FromContext<'a, A>: Sized 230 | where 231 | A: Application, 232 | { 233 | fn from_context( 234 | app: &'a Arc, 235 | req: &'a Parts, 236 | state: &mut PathState, 237 | body: &mut Option, 238 | ) -> Result; 239 | } 240 | 241 | macro_rules! 
from_context_from_str { 242 | ($self:ty) => { 243 | impl<'a, A: Application> FromContext<'a, A> for $self { 244 | fn from_context( 245 | _: &'a Arc, 246 | req: &'a Parts, 247 | state: &mut PathState, 248 | _: &mut Option, 249 | ) -> Result { 250 | let s = state 251 | .next(req.uri.path()) 252 | .ok_or(Error::PathComponentMissing.into())?; 253 | <$self>::from_str(s).map_err(|_| Error::PathParse.into()) 254 | } 255 | } 256 | 257 | impl<'a, A: Application> FromContext<'a, A> for Option<$self> { 258 | fn from_context( 259 | _: &'a Arc, 260 | req: &'a Parts, 261 | state: &mut PathState, 262 | _: &mut Option, 263 | ) -> Result { 264 | match state.next(req.uri.path()) { 265 | Some(s) => match <$self>::from_str(s) { 266 | Ok(v) => Ok(Some(v)), 267 | Err(_) => Err(Error::PathParse.into()), 268 | }, 269 | None => Ok(None), 270 | } 271 | } 272 | } 273 | }; 274 | } 275 | 276 | impl<'a, A: Application> FromContext<'a, A> for &'a A { 277 | fn from_context( 278 | app: &'a Arc, 279 | _: &'a Parts, 280 | _: &mut PathState, 281 | _: &mut Option, 282 | ) -> Result { 283 | Ok(app) 284 | } 285 | } 286 | 287 | impl<'a, A: Application> FromContext<'a, A> for &'a Arc { 288 | fn from_context( 289 | app: &'a Arc, 290 | _: &'a Parts, 291 | _: &mut PathState, 292 | _: &mut Option, 293 | ) -> Result { 294 | Ok(app) 295 | } 296 | } 297 | 298 | impl<'a, A: Application> FromContext<'a, A> for &'a http::request::Parts { 299 | fn from_context( 300 | _: &'a Arc, 301 | req: &'a Parts, 302 | _: &mut PathState, 303 | _: &mut Option, 304 | ) -> Result { 305 | Ok(req) 306 | } 307 | } 308 | 309 | impl<'a, A: Application> FromContext<'a, A> for Option<&'a [u8]> { 310 | fn from_context( 311 | _: &'a Arc, 312 | req: &'a Parts, 313 | state: &mut PathState, 314 | _: &mut Option, 315 | ) -> Result { 316 | Ok(state.next(req.uri.path()).map(|s| s.as_bytes())) 317 | } 318 | } 319 | 320 | impl<'a, A: Application> FromContext<'a, A> for &'a [u8] { 321 | fn from_context( 322 | _: &'a Arc, 323 | req: &'a Parts, 324 | 
state: &mut PathState, 325 | _: &mut Option, 326 | ) -> Result { 327 | state 328 | .next(req.uri.path()) 329 | .ok_or_else(|| Error::PathComponentMissing.into()) 330 | .map(|s| s.as_bytes()) 331 | } 332 | } 333 | 334 | impl<'a, A: Application> FromContext<'a, A> for Option> { 335 | fn from_context( 336 | _: &'a Arc, 337 | req: &'a Parts, 338 | state: &mut PathState, 339 | _: &mut Option, 340 | ) -> Result { 341 | Ok(path_str(req, state)?) 342 | } 343 | } 344 | 345 | impl<'a, A: Application> FromContext<'a, A> for Cow<'a, str> { 346 | fn from_context( 347 | _: &'a Arc, 348 | req: &'a Parts, 349 | state: &mut PathState, 350 | _: &mut Option, 351 | ) -> Result { 352 | match path_str(req, state)? { 353 | Some(s) => Ok(s), 354 | None => Err(Error::PathComponentMissing.into()), 355 | } 356 | } 357 | } 358 | 359 | impl<'a, A: Application> FromContext<'a, A> for Option { 360 | fn from_context( 361 | _: &'a Arc, 362 | req: &'a Parts, 363 | state: &mut PathState, 364 | _: &mut Option, 365 | ) -> Result { 366 | Ok(path_str(req, state)?.map(|s| s.into_owned())) 367 | } 368 | } 369 | 370 | impl<'a, A: Application> FromContext<'a, A> for String { 371 | fn from_context( 372 | _: &'a Arc, 373 | req: &'a Parts, 374 | state: &mut PathState, 375 | _: &mut Option, 376 | ) -> Result { 377 | match path_str(req, state)? 
{ 378 | Some(s) => Ok(s.into_owned()), 379 | None => Err(Error::PathComponentMissing.into()), 380 | } 381 | } 382 | } 383 | 384 | fn path_str<'a>(req: &'a Parts, state: &mut PathState) -> Result>, Error> { 385 | let s = match state.next(req.uri.path()) { 386 | Some(s) => s, 387 | None => return Ok(None), 388 | }; 389 | 390 | percent_decode_str(s) 391 | .decode_utf8() 392 | .map(Some) 393 | .map_err(|_| Error::PathDecode) 394 | } 395 | 396 | from_context_from_str!(bool); 397 | from_context_from_str!(char); 398 | from_context_from_str!(f32); 399 | from_context_from_str!(f64); 400 | from_context_from_str!(i8); 401 | from_context_from_str!(i16); 402 | from_context_from_str!(i32); 403 | from_context_from_str!(i64); 404 | from_context_from_str!(i128); 405 | from_context_from_str!(isize); 406 | from_context_from_str!(u8); 407 | from_context_from_str!(u16); 408 | from_context_from_str!(u32); 409 | from_context_from_str!(u64); 410 | from_context_from_str!(u128); 411 | from_context_from_str!(usize); 412 | 413 | macro_rules! 
deserialize_body { 414 | ($req:ident, $bytes:ident) => {{ 415 | let content_type = $req.headers.get("content-type").ok_or(Error::BodyNoType)?; 416 | let ct_str = content_type.to_str().map_err(|_| { 417 | Error::BodyUnknownType(String::from_utf8_lossy(content_type.as_bytes()).into_owned()) 418 | })?; 419 | 420 | let mut parts = ct_str.splitn(2, ';'); 421 | match parts.next().map(|s| s.trim()) { 422 | Some("application/x-www-form-urlencoded") => { 423 | serde_urlencoded::from_bytes::(&$bytes).map_err(Error::BodyDecodeForm) 424 | } 425 | #[cfg(feature = "json")] 426 | Some("application/json") => { 427 | serde_json::from_slice::(&$bytes).map_err(Error::BodyDecodeJson) 428 | } 429 | #[cfg(feature = "uploads")] 430 | Some("multipart/form-data") => { 431 | crate::forms::from_form_data::(&$req.headers, &$bytes) 432 | .map_err(Error::BodyDecodeMultipart) 433 | } 434 | Some(_) | None => Err(Error::BodyUnknownType(ct_str.to_owned())), 435 | } 436 | }}; 437 | } 438 | 439 | #[doc(hidden)] 440 | pub struct Rest(pub T); 441 | 442 | impl<'a, A: Application> FromContext<'a, A> for Rest<&'a [u8]> { 443 | fn from_context( 444 | _: &'a Arc, 445 | req: &'a Parts, 446 | state: &mut PathState, 447 | _: &mut Option, 448 | ) -> Result { 449 | Ok(Rest(state.rest(req.uri.path()).as_bytes())) 450 | } 451 | } 452 | 453 | impl<'a, A: Application> FromContext<'a, A> for Rest> { 454 | fn from_context( 455 | _: &'a Arc, 456 | req: &'a Parts, 457 | state: &mut PathState, 458 | _: &mut Option, 459 | ) -> Result { 460 | Ok(Rest( 461 | percent_decode_str(state.rest(req.uri.path())) 462 | .decode_utf8() 463 | .map_err(|_| Error::PathDecode)?, 464 | )) 465 | } 466 | } 467 | 468 | #[doc(hidden)] 469 | pub struct Query(pub T); 470 | 471 | impl<'de, 'a: 'de, A: Application, T> FromContext<'a, A> for Query 472 | where 473 | T: serde::Deserialize<'de>, 474 | { 475 | fn from_context( 476 | _: &'a Arc, 477 | req: &'a Parts, 478 | _: &mut PathState, 479 | _: &mut Option, 480 | ) -> Result { 481 | 
A::from_query(req).map(Query) 482 | } 483 | } 484 | 485 | #[cfg(feature = "body-util")] 486 | #[cfg_attr(docsrs, doc(cfg(feature = "body-util")))] 487 | async fn from_body( 488 | req: &Parts, 489 | body: B, 490 | max_len: usize, 491 | ) -> Result 492 | where 493 | B: HttpBody, 494 | B::Error: Into>, 495 | { 496 | let bytes = to_bytes(body, max_len).await?; 497 | deserialize_body!(req, bytes) 498 | } 499 | 500 | fn from_bytes<'de, T: serde::de::Deserialize<'de>>( 501 | req: &Parts, 502 | bytes: &'de [u8], 503 | ) -> Result { 504 | deserialize_body!(req, bytes) 505 | } 506 | 507 | #[cfg(feature = "body-util")] 508 | #[cfg_attr(docsrs, doc(cfg(feature = "body-util")))] 509 | #[cfg_attr(feature = "tracing", tracing::instrument(skip(body)))] 510 | async fn to_bytes(body: B, max_len: usize) -> Result 511 | where 512 | B::Error: Into>, 513 | { 514 | #[cfg(feature = "body-util")] 515 | use http_body_util::BodyExt; 516 | 517 | let limited = http_body_util::Limited::new(body, max_len); 518 | match limited.collect().await { 519 | Ok(collected) => Ok(collected.to_bytes()), 520 | Err(err) => Err(Error::BodyReceive(err)), 521 | } 522 | } 523 | 524 | // This should only be used by procedural routing macros. 525 | #[doc(hidden)] 526 | pub struct PathState { 527 | prev: Option, 528 | next: Option, 529 | } 530 | 531 | impl PathState { 532 | fn new(path: &str) -> Self { 533 | let next = if path.is_empty() || path == "/" { 534 | None 535 | } else if path.find('/') == Some(0) { 536 | Some(1) 537 | } else { 538 | Some(0) 539 | }; 540 | Self { prev: None, next } 541 | } 542 | 543 | // This should only be used by procedural routing macros. 
544 | #[doc(hidden)] 545 | pub fn next<'r>(&mut self, path: &'r str) -> Option<&'r str> { 546 | let start = match self.next.as_ref() { 547 | Some(v) => *v, 548 | None => return None, 549 | }; 550 | 551 | let path = &path[start..]; 552 | if path.is_empty() { 553 | self.prev = self.next.take(); 554 | return None; 555 | } 556 | 557 | match path.find('/') { 558 | Some(end) => { 559 | self.prev = self.next.replace(start + end + 1); 560 | Some(&path[..end]) 561 | } 562 | None => { 563 | self.prev = self.next.take(); 564 | Some(path) 565 | } 566 | } 567 | } 568 | 569 | // This should only be used by procedural routing macros. 570 | #[doc(hidden)] 571 | pub fn rest<'r>(&mut self, path: &'r str) -> &'r str { 572 | let start = match self.next.take() { 573 | Some(v) => v, 574 | None => return "", 575 | }; 576 | 577 | self.prev = Some(start); 578 | &path[start..] 579 | } 580 | 581 | // This should only be used by procedural routing macros. 582 | #[doc(hidden)] 583 | pub fn rewind(&mut self) { 584 | self.next = self.prev.take(); 585 | } 586 | } 587 | 588 | #[derive(Debug, Error)] 589 | pub enum Error { 590 | #[error("method not allowed")] 591 | MethodNotAllowed, 592 | #[error("no matching routes")] 593 | PathNotFound, 594 | #[error("missing path component")] 595 | PathComponentMissing, 596 | #[error("unable to parse path component")] 597 | PathParse, 598 | #[error("unable to decode UTF-8 from path component")] 599 | PathDecode, 600 | #[error("no query in request URL")] 601 | QueryMissing, 602 | #[error("unable to decode request URI query: {0}")] 603 | QueryDecode(serde_urlencoded::de::Error), 604 | #[cfg(feature = "body-util")] 605 | #[error("unable to receive request body: {0}")] 606 | BodyReceive(Box), 607 | #[cfg(feature = "body-util")] 608 | #[error("request body too large")] 609 | BodyTooLarge, 610 | #[cfg(feature = "json")] 611 | #[error("unable to decode body as JSON: {0}")] 612 | BodyDecodeJson(#[from] serde_json::Error), 613 | #[error("unable to decode body as form 
data: {0}")] 614 | BodyDecodeForm(serde_urlencoded::de::Error), 615 | #[cfg(feature = "uploads")] 616 | #[error("unable to decode body as multipart form data: {0}")] 617 | BodyDecodeMultipart(#[from] crate::multipart::Error), 618 | #[error("content type on request body unknown: {0}")] 619 | BodyUnknownType(String), 620 | #[error("no content type on request body")] 621 | BodyNoType, 622 | #[cfg(feature = "static")] 623 | #[error("file not found")] 624 | FileNotFound, 625 | } 626 | 627 | impl From<&Error> for StatusCode { 628 | fn from(e: &Error) -> StatusCode { 629 | use Error::*; 630 | match e { 631 | MethodNotAllowed => StatusCode::METHOD_NOT_ALLOWED, 632 | QueryMissing | QueryDecode(_) | BodyNoType => StatusCode::BAD_REQUEST, 633 | BodyUnknownType(_) => StatusCode::UNSUPPORTED_MEDIA_TYPE, 634 | PathNotFound | PathComponentMissing | PathParse | PathDecode => StatusCode::NOT_FOUND, 635 | #[cfg(feature = "body-util")] 636 | BodyReceive(_) => StatusCode::INTERNAL_SERVER_ERROR, 637 | #[cfg(feature = "body-util")] 638 | BodyTooLarge => StatusCode::BAD_REQUEST, 639 | BodyDecodeForm(_) => StatusCode::UNPROCESSABLE_ENTITY, 640 | #[cfg(feature = "json")] 641 | BodyDecodeJson(_) => StatusCode::UNPROCESSABLE_ENTITY, 642 | #[cfg(feature = "uploads")] 643 | BodyDecodeMultipart(_) => StatusCode::UNPROCESSABLE_ENTITY, 644 | #[cfg(feature = "static")] 645 | FileNotFound => StatusCode::NOT_FOUND, 646 | } 647 | } 648 | } 649 | --------------------------------------------------------------------------------