├── .gitignore ├── examples ├── postgres │ ├── .gitignore │ ├── create_db.sql │ ├── Cargo.toml │ └── src │ │ ├── pool.rs │ │ └── main.rs ├── tokio-tophat │ ├── .gitignore │ ├── Cargo.toml │ └── src │ │ └── main.rs ├── bench.rs ├── bench-multi.rs ├── basic.rs ├── observability_bytes_written.rs ├── errors.rs ├── routing.rs ├── server_sent_events.rs ├── cors.rs ├── server_sent_events_2.rs ├── errors_verbose.rs ├── middleware.rs └── identity.rs ├── src ├── chunked │ ├── mod.rs │ └── encoder.rs ├── request.rs ├── response.rs ├── client │ ├── mod.rs │ ├── error.rs │ ├── decode.rs │ └── encode.rs ├── lib.rs ├── server │ ├── error.rs │ ├── mod.rs │ ├── router.rs │ ├── encode.rs │ ├── response_writer.rs │ ├── identity.rs │ ├── glitch.rs │ ├── decode.rs │ └── cors.rs ├── timeout.rs ├── trailers.rs ├── util.rs └── body.rs ├── CHANGELOG.md ├── LICENSE-MIT ├── justfile ├── Cargo.toml ├── tests ├── client_basic.rs ├── server_error_handling.rs ├── mock.rs └── server_basic.rs ├── .github └── workflows │ └── ci.yaml ├── NOTES.md ├── README.md └── LICENSE-APACHE /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | .cookie 4 | .env 5 | -------------------------------------------------------------------------------- /examples/postgres/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | .env 4 | -------------------------------------------------------------------------------- /examples/tokio-tophat/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | .env 4 | -------------------------------------------------------------------------------- /src/chunked/mod.rs: -------------------------------------------------------------------------------- 1 | mod decoder; 2 | mod encoder; 3 | 4 | pub(crate) use decoder::ChunkedDecoder; 5 | pub(crate) use encoder::ChunkedEncoder; 6 | -------------------------------------------------------------------------------- /src/request.rs: -------------------------------------------------------------------------------- 1 | use http::Request as HttpRequest; 2 | 3 | use crate::body::Body; 4 | 5 | /// Currently, Request is not generic over Body type 6 | pub type Request = HttpRequest; 7 | -------------------------------------------------------------------------------- /src/response.rs: -------------------------------------------------------------------------------- 1 | use http::Response as HttpResponse; 2 | 3 | use crate::body::Body; 4 | 5 | /// Currently, Response is not generic over Body type 6 | pub type Response = HttpResponse; 7 | -------------------------------------------------------------------------------- /examples/postgres/create_db.sql: -------------------------------------------------------------------------------- 1 | create table if not exists test_users (id int, planet text, organization text); 2 | 3 | insert into test_users (id, planet, organization) values 4 | (1,'tatooine','jedi'), 5 | (2,'tatooine','new republic'), 6 | (3,'nevarro','guild'), 7 | (4,'mandalore','guild'); 8 | -------------------------------------------------------------------------------- /examples/tokio-tophat/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tokio-tophat" 3 | version = "0.1.0" 4 | authors = ["Walther Chen "] 5 | edition = "2018" 6 | 7 | # See more keys and their definitions at 
https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [dependencies] 10 | async-dup = "1.2.2" 11 | futures-lite = "1.11.3" 12 | tokio = { version = "1", features = ["full"] } 13 | tophat = { path = "../../" } 14 | http = "0.2.2" 15 | tokio-util = { version = "0.6.0", features = ["compat"] } 16 | tracing-subscriber = "0.2.15" -------------------------------------------------------------------------------- /src/client/mod.rs: -------------------------------------------------------------------------------- 1 | //! Simple client for HTTP/1.1 2 | 3 | mod decode; 4 | mod encode; 5 | mod error; 6 | 7 | use futures_lite::{io, AsyncRead, AsyncWrite}; 8 | 9 | use crate::{Request, Response}; 10 | use decode::decode; 11 | use encode::Encoder; 12 | use error::ClientError; 13 | 14 | /// Opens an HTTP/1.1 connection to a remote host. 15 | pub async fn connect<RW>(mut stream: RW, req: Request) -> Result<Response, ClientError> 16 | where 17 | RW: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, 18 | { 19 | let mut req = Encoder::encode(req).await?; 20 | 21 | io::copy(&mut req, &mut stream).await.map_err(error::io)?; 22 | 23 | let res = decode(stream).await?; 24 | 25 | Ok(res) 26 | } 27 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![deny(unsafe_code)] 2 | #![warn(missing_docs)] 3 | 4 | //! # tophat 5 | //! 6 | //! A small, pragmatic, and flexible async HTTP server library. 7 | //! 8 | //! More docs coming soon! For now, please see the examples directory for features. 9 | //! 10 | //! Also, please note that you'll need to set up your own async runtime to work with tophat. All 11 | //! the examples use `smol` as the runtime. 12 | 13 | mod body; 14 | mod chunked; 15 | pub mod client; 16 | mod request; 17 | mod response; 18 | pub mod server; 19 | mod timeout; 20 | pub mod trailers; 21 | mod util; 22 | 23 | /// Re-export http crate for convenience 24 | pub use http; 25 | 26 | pub use body::Body; 27 | pub use request::Request; 28 | pub use response::Response; 29 | -------------------------------------------------------------------------------- /examples/postgres/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "postgres-example" 3 | version = "0.1.0" 4 | authors = ["Walther Chen "] 5 | edition = "2018" 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [dependencies] 10 | anyhow = "1.0.37" 11 | async-dup = "1.2.2" 12 | async-trait = "0.1.42" 13 | deadpool = "0.7.0" 14 | dotenv = "0.15.0" 15 | futures-lite = "1.11.3" 16 | tophat = { path = "../../", features = ["router"] } 17 | http = "0.2.2" 18 | smol = "1.2.5" 19 | thiserror = "1.0.23" 20 | tokio-postgres = { version = "0.7.0", default_features = false } 21 | tokio-util = { version = "0.6.0", features = ["compat"] } 22 | tracing = "0.1.22" 23 | tracing-subscriber = "0.2.15" 24 | url = "2.2.0" 25 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # 2020-05-19, v0.2.0 2 | ## Features 3 | - `ResponseWriter` now holds a `Response`, so there's no need to create one separately. 4 | - Convenience methods on `ResponseWriter`. 5 | - `Glitch` and `GlitchExt` for converting errors into error responses. 6 | - `ResponseWritten` no longer creatable by user. 
7 | - Remove unwrap macros (were they ever a good idea?). 8 | - Router now behind feature gate. 9 | - Cors, feature gated. 10 | - Identity, feature gated. 11 | - "Middleware" philosophy confirmed. (no specific framework for it) 12 | - Beginning of docs. 13 | 14 | 15 | ## Internal 16 | - remove `mime` crate. 17 | - pub use `http` crate. 18 | - remove more unwraps. 19 | - ci on all features 20 | - remove clippy on stable (nightly has different lints) 21 | - anyhow was added then removed. 22 | -------------------------------------------------------------------------------- /examples/bench.rs: -------------------------------------------------------------------------------- 1 | use async_dup::Arc; 2 | use smol::Async; 3 | use std::net::TcpListener; 4 | use tophat::server::accept; 5 | 6 | fn main() -> Result<(), Box> { 7 | let listener = Async::::bind(([127,0,0,1],9999))?; 8 | 9 | smol::block_on(async { 10 | loop { 11 | let (stream, _) = listener.accept().await?; 12 | let stream = Arc::new(stream); 13 | 14 | let task = smol::spawn(async move { 15 | let serve = accept(stream, |_req, resp_wtr| async { resp_wtr.send().await }).await; 16 | 17 | if let Err(err) = serve { 18 | eprintln!("Error: {}", err); 19 | } 20 | }); 21 | 22 | task.detach(); 23 | } 24 | }) 25 | } 26 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 tophat Developers 2 | 3 | Permission is hereby granted, free of charge, to any 4 | person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the 6 | Software without restriction, including without 7 | limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software 10 | is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice 14 | shall be included in all copies or substantial portions 15 | of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 18 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 19 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 20 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 21 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 24 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 25 | DEALINGS IN THE SOFTWARE. 
26 | -------------------------------------------------------------------------------- /justfile: -------------------------------------------------------------------------------- 1 | watch: 2 | cargo watch -x 'check --all-features --examples --tests' 3 | 4 | test filter='': 5 | cargo watch -x 'test {{filter}}' 6 | 7 | test-anyhow: 8 | cargo watch -x 'test --features=anyhow -- --nocapture' 9 | 10 | bench: 11 | RUST_LOG=info cargo watch -x 'run --release --example bench' 12 | 13 | basic: 14 | RUST_LOG=info cargo watch -x 'run --release --example basic' 15 | 16 | routing: 17 | RUST_LOG=info cargo watch -x 'run --release --example routing --features="router"' 18 | 19 | routing_2: 20 | RUST_LOG=info cargo watch -x 'run --release --example routing_2 --features="router"' 21 | 22 | identity: 23 | RUST_LOG=info cargo watch --ignore .cookie -x 'run --release --example identity --features="router identity"' 24 | 25 | identity-login: 26 | curl -v --cookie .cookie --cookie-jar .cookie --location localhost:9999/login/test 27 | 28 | identity-hello: 29 | curl -v --cookie .cookie localhost:9999/ 30 | 31 | identity-logout: 32 | curl -v --cookie .cookie --cookie-jar .cookie --location localhost:9999/logout 33 | 34 | clippy: 35 | cargo watch -x '+nightly clippy --all-features -- -D warnings -Z unstable-options' -------------------------------------------------------------------------------- /examples/bench-multi.rs: -------------------------------------------------------------------------------- 1 | use async_channel::unbounded; 2 | use async_dup::Arc; 3 | use easy_parallel::Parallel; 4 | use smol::{future, Async, Executor}; 5 | use std::net::TcpListener; 6 | use tophat::server::accept; 7 | 8 | fn main() -> Result<(), Box<dyn std::error::Error>> { 9 | let ex = Executor::new(); 10 | let (signal, shutdown) = unbounded::<()>(); 11 | 12 | Parallel::new() 13 | .each(0..num_cpus::get().max(1), |_| future::block_on(ex.run(shutdown.recv()))) 14 | .finish(|| future::block_on(async { 15 | drop(signal); 16 | })); 17 | 18 | let listener = Async::<TcpListener>::bind(([127,0,0,1],9999))?; 19 | 20 | smol::block_on(async { 21 | loop { 22 | let (stream, _) = listener.accept().await?; 23 | let stream = Arc::new(stream); 24 | 25 | let task = smol::spawn(async move { 26 | let serve = accept(stream, |_req, resp_wtr| async { resp_wtr.send().await }).await; 27 | 28 | if let Err(err) = serve { 29 | eprintln!("Error: {}", err); 30 | } 31 | }); 32 | 33 | task.detach(); 34 | } 35 | }) 36 | } 37 | -------------------------------------------------------------------------------- /src/server/error.rs: -------------------------------------------------------------------------------- 1 | //! Errors that indicate system failure, user error in using tophat, or closed connection. 2 | //! 3 | //! "App" errors, which are handled within an endpoint and result only in logging and an HTTP 4 | //! response, are handled by `Glitch`. 5 | 6 | use std::fmt; 7 | 8 | /// Public Errors (does not include internal fails) 9 | #[derive(Debug)] 10 | pub enum ServerError { 11 | /// Error because tophat does not support the transfer encoding. 
12 | ConnectionClosedUnsupportedTransferEncoding, 13 | 14 | /// Connection lost 15 | ConnectionLost(std::io::Error), 16 | } 17 | 18 | impl std::error::Error for ServerError { 19 | fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 20 | use ServerError::*; 21 | match self { 22 | ConnectionClosedUnsupportedTransferEncoding => None, 23 | ConnectionLost(err) => Some(err), 24 | } 25 | } 26 | } 27 | 28 | impl fmt::Display for ServerError { 29 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 30 | use ServerError::*; 31 | match self { 32 | ConnectionClosedUnsupportedTransferEncoding => { 33 | write!(f, "Connection closed: Unsupported Transfer Encoding") 34 | } 35 | ConnectionLost(err) => write!(f, "Connection lost: {}", err), 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /examples/basic.rs: -------------------------------------------------------------------------------- 1 | use async_dup::Arc; 2 | use http::header; 3 | use smol::Async; 4 | use std::net::TcpListener; 5 | use tophat::server::accept; 6 | 7 | fn main() -> Result<(), Box> { 8 | tracing_subscriber::fmt::init(); 9 | 10 | let listener = Async::::bind(([127,0,0,1],9999))?; 11 | 12 | smol::block_on(async { 13 | loop { 14 | let (stream, _) = listener.accept().await?; 15 | let stream = Arc::new(stream); 16 | 17 | let task = smol::spawn(async move { 18 | let serve = accept(stream, |req, mut resp_wtr| async { 19 | println!("{:?}", *req.uri()); 20 | println!("{:?}", req.version()); 21 | println!("{:?}", req.method()); 22 | println!("{:?}", req.headers().get(header::CONTENT_LENGTH)); 23 | println!("{:?}", req.headers().get(header::HOST)); 24 | 25 | let req_body = req.into_body().into_string().await?; 26 | let resp_body = format!("Hello, {}!", req_body); 27 | resp_wtr.set_body(resp_body.into()); 28 | 29 | resp_wtr.send().await 30 | }) 31 | .await; 32 | 33 | if let Err(err) = serve { 34 | eprintln!("Error: {}", err); 35 | } 36 | }); 37 | 38 | task.detach(); 39 | } 40 | }) 41 | } 42 | -------------------------------------------------------------------------------- /examples/observability_bytes_written.rs: -------------------------------------------------------------------------------- 1 | use async_dup::Arc; 2 | use http::header; 3 | use smol::Async; 4 | use std::net::TcpListener; 5 | use tophat::server::accept; 6 | 7 | fn main() -> Result<(), Box> { 8 | tracing_subscriber::fmt::init(); 9 | 10 | let listener = Async::::bind(([127,0,0,1],9999))?; 11 | 12 | smol::block_on(async { 13 | loop { 14 | let (stream, _) = listener.accept().await?; 15 | let stream = Arc::new(stream); 16 | 17 | let task = smol::spawn(async move { 18 | let serve = accept(stream, |req, mut resp_wtr| async { 19 | println!("{:?}", *req.uri()); 20 | println!("{:?}", req.version()); 21 | println!("{:?}", req.method()); 22 | println!("{:?}", req.headers().get(header::CONTENT_LENGTH)); 23 | println!("{:?}", req.headers().get(header::HOST)); 24 | 25 | let req_body = req.into_body().into_string().await?; 26 | let resp_body = format!("Hello, {}!", req_body); 27 | resp_wtr.set_body(resp_body.into()); 28 | 29 | let done = resp_wtr.send().await?; 30 | 31 | println!("Bytes written: {}", done.bytes_written()); 32 | 33 | Ok(done) 34 | 35 | }) 36 | .await; 37 | 38 | if let Err(err) = serve { 39 | eprintln!("Error: {}", err); 40 | } 41 | }); 42 | 43 | task.detach(); 44 | } 45 | }) 46 | } 47 | -------------------------------------------------------------------------------- /src/timeout.rs: 
-------------------------------------------------------------------------------- 1 | // From async-std future::timeout, except that futures_timer is swapped in. 2 | 3 | use std::error::Error; 4 | use std::fmt; 5 | use std::future::Future; 6 | use std::pin::Pin; 7 | use std::task::{Context, Poll}; 8 | use std::time::Duration; 9 | 10 | use futures_timer::Delay; 11 | use pin_project_lite::pin_project; 12 | 13 | 14 | pub(crate) async fn timeout(dur: Duration, f: F) -> Result 15 | where 16 | F: Future, 17 | { 18 | let f = TimeoutFuture { 19 | future: f, 20 | delay: Delay::new(dur), 21 | }; 22 | f.await 23 | } 24 | 25 | pin_project! { 26 | /// A future that times out after a duration of time. 27 | pub(crate) struct TimeoutFuture { 28 | #[pin] 29 | future: F, 30 | #[pin] 31 | delay: Delay, 32 | } 33 | } 34 | 35 | impl TimeoutFuture { 36 | #[allow(dead_code)] 37 | pub(crate) fn new(future: F, dur: Duration) -> TimeoutFuture { 38 | TimeoutFuture { 39 | future, 40 | delay: Delay::new(dur), 41 | } 42 | } 43 | } 44 | 45 | impl Future for TimeoutFuture { 46 | type Output = Result; 47 | 48 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 49 | let this = self.project(); 50 | match this.future.poll(cx) { 51 | Poll::Ready(v) => Poll::Ready(Ok(v)), 52 | Poll::Pending => match this.delay.poll(cx) { 53 | Poll::Ready(_) => Poll::Ready(Err(TimeoutError { _private: () })), 54 | Poll::Pending => Poll::Pending, 55 | }, 56 | } 57 | } 58 | } 59 | 60 | /// An error returned when a future times out. 61 | #[derive(Clone, Copy, Debug, Eq, PartialEq)] 62 | pub(crate) struct TimeoutError { 63 | _private: (), 64 | } 65 | 66 | impl Error for TimeoutError {} 67 | 68 | impl fmt::Display for TimeoutError { 69 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 70 | "future has timed out".fmt(f) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /src/trailers.rs: -------------------------------------------------------------------------------- 1 | //! # Trailers 2 | //! 3 | //! Trailers are headers sent at the end of a chunked message. 4 | //! 5 | //! Currently tophat can only receive, not send them. 6 | 7 | use async_channel::Sender; 8 | use http::HeaderMap; 9 | use std::ops::{Deref, DerefMut}; 10 | 11 | use crate::body::error::BodyError; 12 | 13 | /// A collection of trailing HTTP headers. 14 | #[derive(Debug)] 15 | pub struct Trailers { 16 | /// The headers in a trailer 17 | pub headers: HeaderMap, 18 | } 19 | 20 | impl Trailers { 21 | /// Create a new instance of `Trailers`. 22 | pub fn new() -> Self { 23 | Self::default() 24 | } 25 | } 26 | 27 | impl Default for Trailers { 28 | fn default() -> Self { 29 | Self { 30 | headers: HeaderMap::new(), 31 | } 32 | } 33 | } 34 | 35 | impl Clone for Trailers { 36 | fn clone(&self) -> Self { 37 | Self { 38 | headers: self.headers.clone(), 39 | } 40 | } 41 | } 42 | 43 | impl Deref for Trailers { 44 | type Target = HeaderMap; 45 | 46 | fn deref(&self) -> &Self::Target { 47 | &self.headers 48 | } 49 | } 50 | 51 | impl DerefMut for Trailers { 52 | fn deref_mut(&mut self) -> &mut Self::Target { 53 | &mut self.headers 54 | } 55 | } 56 | 57 | /// The sending half of a channel to send trailers. 58 | /// 59 | /// Unlike `async_std::sync::channel` the `send` method on this type can only be 60 | /// called once, and cannot be cloned. That's because only a single instance of 61 | /// `Trailers` should be created. 
62 | #[derive(Debug)] 63 | pub struct TrailersSender { 64 | sender: Sender>, 65 | } 66 | 67 | impl TrailersSender { 68 | /// Create a new instance of `TrailersSender`. 69 | #[doc(hidden)] 70 | pub(crate) fn new(sender: Sender>) -> Self { 71 | Self { sender } 72 | } 73 | 74 | /// Send a `Trailer`. 75 | /// 76 | /// The channel will be consumed after having sent trailers. 77 | pub(crate) async fn send(self, trailers: Result) { 78 | // TODO should this return an error? 79 | let _ = self.sender.send(trailers).await; 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /examples/errors.rs: -------------------------------------------------------------------------------- 1 | use async_dup::Arc; 2 | use futures_util::io::{AsyncRead, AsyncWrite}; 3 | use http::Method; 4 | use smol::Async; 5 | use std::net::TcpListener; 6 | use tophat::{ 7 | server::{ 8 | accept, 9 | glitch::{Glitch, Result}, 10 | router::Router, 11 | ResponseWriter, ResponseWritten, 12 | }, 13 | Request, 14 | }; 15 | 16 | fn main() -> std::result::Result<(), Box> { 17 | tracing_subscriber::fmt::init(); 18 | 19 | let router = Router::build() 20 | .data("Data from datastore") 21 | .at(Method::GET, "/database_error", database_error) 22 | .at(Method::GET, "/missing_data", missing_data) 23 | .finish(); 24 | 25 | let listener = Async::::bind(([127,0,0,1],9999))?; 26 | 27 | smol::block_on(async { 28 | loop { 29 | let router = router.clone(); 30 | 31 | let (stream, _) = listener.accept().await?; 32 | let stream = Arc::new(stream); 33 | 34 | let task = smol::spawn(async move { 35 | let serve = accept(stream, |req, resp_wtr| async { 36 | let res = router.route(req, resp_wtr).await; 37 | res 38 | }) 39 | .await; 40 | 41 | if let Err(err) = serve { 42 | eprintln!("Error: {}", err); 43 | } 44 | }); 45 | 46 | task.detach(); 47 | } 48 | }) 49 | } 50 | 51 | async fn database_error(_req: Request, resp_wtr: ResponseWriter) -> Result 52 | where 53 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 54 | { 55 | use std::io; 56 | 57 | let failed_db = Err(io::Error::new(io::ErrorKind::Other, "")); 58 | failed_db?; // returns a 500 automatically. 59 | 60 | resp_wtr.send().await 61 | } 62 | 63 | async fn missing_data(_req: Request, resp_wtr: ResponseWriter) -> Result 64 | where 65 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 66 | { 67 | let failed_db = None; 68 | 69 | // Manually create a 400 70 | // This will work even without anyhow integration. 
71 | failed_db.ok_or_else(|| Glitch::bad_request())?; 72 | 73 | resp_wtr.send().await 74 | } 75 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tophat" 3 | version = "0.3.0" 4 | authors = ["Walther Chen "] 5 | description = "A small, pragmatic, and flexible async http server" 6 | keywords = ["http"] 7 | categories = ["web-programming::http-server"] 8 | edition = "2018" 9 | license = "MIT OR Apache-2.0" 10 | repository = "https://github.com/hwchen/tophat" 11 | readme = "README.md" 12 | 13 | [package.metadata.docs.rs] 14 | all-features = true 15 | 16 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 17 | 18 | [dependencies] 19 | async-channel = "1.5.1" 20 | async-dup = "1.2.2" 21 | byte-pool = "0.2.2" 22 | futures-lite = "1.11.3" 23 | futures-timer = "3.0.2" 24 | futures-util = { version = "0.3.8", features = ["io"] } 25 | http = "0.2.2" 26 | httpdate = "0.3.2" 27 | httparse = "1.3.4" 28 | lazy_static = "1.4.0" 29 | pin-project-lite = "0.2.0" 30 | tracing = "0.1.22" 31 | 32 | # for router 33 | path-tree = { version = "0.1.12", optional = true } 34 | type-map = { version = "0.3.0", optional = true } 35 | 36 | # for identity 37 | cookie = { version = "0.14.3", optional = true } 38 | jsonwebtoken = { version = "7.2.0", optional = true } 39 | serde = { version = "1.0.118", features = ["derive"], optional = true } 40 | time = { version = "0.2.23", default_features = false, optional = true } 41 | 42 | # for cors (maybe use elsewhere?) 43 | headers = { version = "0.3.2", optional = true } 44 | 45 | [features] 46 | # Nothing enabled by default 47 | default = [] 48 | 49 | cors = ["headers"] 50 | 51 | router = [ 52 | "path-tree", 53 | "type-map", 54 | ] 55 | 56 | identity = [ 57 | "cookie", 58 | "jsonwebtoken", 59 | "serde", 60 | "time", 61 | ] 62 | 63 | [dev-dependencies] 64 | async-channel = "1.5.1" 65 | async-dup = "1.2.2" 66 | easy-parallel = "3.1.0" 67 | futures = "0.3.8" 68 | num_cpus = "1.13.0" 69 | smol = "1.2.5" 70 | tracing-subscriber = "0.2.15" 71 | 72 | [[example]] 73 | name = "routing" 74 | required-features = ["router"] 75 | 76 | [[example]] 77 | name = "identity" 78 | required-features = ["identity", "router"] 79 | 80 | [[example]] 81 | name = "middleware" 82 | required-features = ["router"] 83 | 84 | [[example]] 85 | name = "cors" 86 | required-features = ["router", "cors"] 87 | 88 | [[example]] 89 | name = "errors" 90 | required-features = ["router"] 91 | 92 | [[example]] 93 | name = "errors_verbose" 94 | required-features = ["router"] 95 | -------------------------------------------------------------------------------- /examples/tokio-tophat/src/main.rs: -------------------------------------------------------------------------------- 1 | use async_dup::{Arc, Mutex}; 2 | use std::io; 3 | use std::pin::Pin; 4 | use std::task::{Context, Poll}; 5 | use tokio::net::{self, TcpStream}; 6 | use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; 7 | use tophat::server::accept; 8 | 9 | #[tokio::main] 10 | async fn main() -> Result<(), Box> { 11 | tracing_subscriber::fmt::init(); 12 | 13 | let listener = net::TcpListener::bind("127.0.0.1:9999").await?; 14 | 15 | loop { 16 | let (stream, _) = listener.accept().await?; 17 | let stream = WrapStream::new(stream); 18 | 19 | tokio::spawn(async move { 20 | let serve = accept(stream, |_req, mut resp_wtr| async { 21 | let resp_body = "Hello, 
World!"; 22 | resp_wtr.set_body(resp_body.into()); 23 | 24 | resp_wtr.send().await 25 | }) 26 | .await; 27 | 28 | if let Err(err) = serve { 29 | eprintln!("Error: {}", err); 30 | } 31 | }); 32 | } 33 | } 34 | 35 | // TODO I'm not sure this is the best way to do this. Suggestions for simplifying definitely 36 | // welcome. When AsyncRead and AsyncWrite standardized, this shouldn't be necessary. 37 | #[derive(Clone)] 38 | struct WrapStream(Arc>>); 39 | 40 | impl WrapStream { 41 | fn new(stream: TcpStream) -> Self { 42 | let stream = stream.compat_write(); 43 | WrapStream(Arc::new(Mutex::new(stream))) 44 | } 45 | } 46 | 47 | impl futures_lite::AsyncRead for WrapStream { 48 | fn poll_read( 49 | self: Pin<&mut Self>, 50 | cx: &mut Context<'_>, 51 | buf: &mut [u8], 52 | ) -> Poll> { 53 | Pin::new(&mut *(&*self).0.lock()).poll_read(cx, buf) 54 | } 55 | } 56 | 57 | impl futures_lite::AsyncWrite for WrapStream { 58 | fn poll_write( 59 | self: Pin<&mut Self>, 60 | cx: &mut Context<'_>, 61 | buf: &[u8], 62 | ) -> Poll> { 63 | Pin::new(&mut *(&*self).0.lock()).poll_write(cx, buf) 64 | } 65 | 66 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 67 | Pin::new(&mut *(&*self).0.lock()).poll_flush(cx) 68 | } 69 | 70 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 71 | Pin::new(&mut *(&*self).0.lock()).poll_close(cx) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /examples/routing.rs: -------------------------------------------------------------------------------- 1 | use async_dup::Arc; 2 | use futures_util::io::{AsyncRead, AsyncWrite}; 3 | use http::Method; 4 | use smol::Async; 5 | use std::net::TcpListener; 6 | use tophat::{ 7 | server::{ 8 | accept, 9 | glitch::Result, 10 | router::{Router, RouterRequestExt}, 11 | ResponseWriter, ResponseWritten, 12 | }, 13 | Request, 14 | }; 15 | 16 | fn main() -> std::result::Result<(), Box> { 17 | tracing_subscriber::fmt::init(); 18 | 19 | let router = Router::build() 20 | .data("Data from datastore") 21 | .at(Method::GET, "/:name", hello_user) 22 | .at(Method::GET, "/", blank) 23 | .finish(); 24 | 25 | let listener = Async::::bind(([127,0,0,1],9999))?; 26 | 27 | smol::block_on(async { 28 | loop { 29 | let router = router.clone(); 30 | 31 | let (stream, _) = listener.accept().await?; 32 | let stream = Arc::new(stream); 33 | 34 | let task = smol::spawn(async move { 35 | let serve = accept(stream, |req, resp_wtr| async { 36 | let res = router.route(req, resp_wtr).await; 37 | res 38 | }) 39 | .await; 40 | 41 | if let Err(err) = serve { 42 | eprintln!("Error: {}", err); 43 | } 44 | }); 45 | 46 | task.detach(); 47 | } 48 | }) 49 | } 50 | 51 | async fn hello_user(req: Request, mut resp_wtr: ResponseWriter) -> Result 52 | where 53 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 54 | { 55 | //smol::Timer::after(std::time::Duration::from_secs(5)).await; 56 | 57 | let mut resp_body = format!("Hello, "); 58 | 59 | // add params to body string 60 | if let Some(params) = req.params() { 61 | for (k, v) in params { 62 | resp_body.push_str(&format!("{} = {}", k, v)); 63 | } 64 | } 65 | 66 | // add data to body string 67 | if let Some(data_string) = req.data::<&str>() { 68 | resp_body.push_str(&format!(" and {}", *data_string)); 69 | } 70 | 71 | resp_wtr.set_body(resp_body.into()); 72 | 73 | resp_wtr.send().await 74 | } 75 | 76 | async fn blank(_req: Request, resp_wtr: ResponseWriter) -> Result 77 | where 78 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 
'static, 79 | { 80 | resp_wtr.send().await 81 | } 82 | -------------------------------------------------------------------------------- /examples/postgres/src/pool.rs: -------------------------------------------------------------------------------- 1 | //! Pool for db connections 2 | //! 3 | //! Based on deadpool-postgres. For more config and features (like prepared statements), use its 4 | //! implementation as a starting point. 5 | 6 | use anyhow::Context as _; 7 | use async_trait::async_trait; 8 | use smol::Async; 9 | use std::net::{TcpStream, ToSocketAddrs}; 10 | use tokio_postgres::{tls::NoTls, Client}; 11 | use tokio_util::compat::FuturesAsyncWriteCompatExt; 12 | use tracing::debug; 13 | 14 | pub(crate) type Pool = deadpool::managed::Pool; 15 | pub (crate) type RecycleError = deadpool::managed::RecycleError; 16 | 17 | pub (crate) struct Manager { 18 | pg_config: tokio_postgres::config::Config, 19 | socket_addr: std::net::SocketAddr, 20 | } 21 | 22 | impl Manager { 23 | pub(crate) fn new(db_url: &str) -> Result { 24 | let pg_config = db_url.parse()?; 25 | 26 | let db_url: url::Url = db_url.parse()?; 27 | // Figure out the host and the port. 28 | let host = db_url.host().context("cannot parse host")?.to_string(); 29 | let port = db_url 30 | .port() 31 | .unwrap_or(5432); 32 | 33 | // Connect to the host. 34 | let socket_addr = (host.as_str(), port) 35 | .to_socket_addrs()? 36 | .next() 37 | .context("cannot resolve address")?; 38 | 39 | Ok(Self { 40 | pg_config, 41 | socket_addr, 42 | }) 43 | } 44 | } 45 | 46 | #[async_trait] 47 | impl deadpool::managed::Manager for Manager { 48 | async fn create(&self) -> Result { 49 | debug!("Pool: create client"); 50 | let stream = Async::::connect(self.socket_addr).await?; 51 | let stream = stream.compat_write(); 52 | let (client, connection) = self.pg_config.connect_raw(stream, NoTls).await?; 53 | smol::spawn(connection).detach(); 54 | 55 | Ok(client) 56 | } 57 | 58 | async fn recycle(&self, client: &mut Client) -> Result<(), RecycleError> { 59 | debug!("Pool: recycle client"); 60 | if client.is_closed() { 61 | return Err(RecycleError::Message("Connection closed".to_string())); 62 | } 63 | // "fast" recycling method from doesn't run a query 64 | //client.simple_query(None).await 65 | Ok(()) 66 | } 67 | } 68 | 69 | #[derive(Debug, thiserror::Error)] 70 | pub(crate) enum Error { 71 | #[error("io error")] 72 | Io(#[from] std::io::Error), 73 | #[error("postgres error")] 74 | Postgres(#[from] tokio_postgres::Error), 75 | } 76 | -------------------------------------------------------------------------------- /examples/server_sent_events.rs: -------------------------------------------------------------------------------- 1 | use async_dup::{Arc, Mutex}; 2 | use futures::Stream; 3 | use smol::Async; 4 | use std::net::TcpListener; 5 | use std::pin::Pin; 6 | use std::task::{Context, Poll}; 7 | use std::time::Duration; 8 | use tophat::server::accept; 9 | 10 | fn main() -> Result<(), Box> { 11 | tracing_subscriber::fmt::init(); 12 | 13 | let ping_machine = Arc::new(Mutex::new(PingMachine { 14 | broadcasters: Vec::new(), 15 | })); 16 | 17 | let listener = Async::::bind(([127,0,0,1],9999))?; 18 | 19 | let ping_task = smol::spawn({ 20 | let ping_machine = ping_machine.clone(); 21 | async move { 22 | loop { 23 | ping_machine.lock().ping().await; 24 | smol::Timer::after(Duration::from_secs(1)).await; 25 | } 26 | } 27 | }); 28 | ping_task.detach(); 29 | 30 | smol::block_on(async { 31 | loop { 32 | let (stream, _) = listener.accept().await?; 33 | let stream = 
Arc::new(stream); 34 | 35 | let ping_machine = ping_machine.clone(); 36 | 37 | let task = smol::spawn(async move { 38 | let serve = accept(stream, |_req, mut resp_wtr| async { 39 | let client = ping_machine.lock().add_client(); 40 | resp_wtr.set_sse(client); 41 | 42 | resp_wtr.send().await 43 | }) 44 | .await; 45 | 46 | if let Err(err) = serve { 47 | eprintln!("Error: {}", err); 48 | } 49 | }); 50 | 51 | task.detach(); 52 | } 53 | }) 54 | } 55 | 56 | struct PingMachine { 57 | broadcasters: Vec>, 58 | } 59 | 60 | impl PingMachine { 61 | async fn ping(&self) { 62 | for tx in &self.broadcasters { 63 | let _ = tx.send("data: ping\n\n".to_owned()).await; 64 | } 65 | } 66 | 67 | fn add_client(&mut self) -> Client { 68 | let (tx, rx) = async_channel::bounded(10); 69 | 70 | self.broadcasters.push(tx); 71 | 72 | Client(rx) 73 | } 74 | } 75 | 76 | struct Client(async_channel::Receiver); 77 | 78 | impl Stream for Client { 79 | type Item = Result; 80 | 81 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 82 | match Pin::new(&mut self.0).poll_next(cx) { 83 | Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))), 84 | Poll::Ready(None) => Poll::Ready(None), 85 | Poll::Pending => Poll::Pending, 86 | } 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /tests/client_basic.rs: -------------------------------------------------------------------------------- 1 | //! Client tests are far from comprehensive; more tests are welcome. 2 | 3 | mod mock; 4 | 5 | use http::{ header, StatusCode, Uri}; 6 | use tophat::{client::connect, Body, Request}; 7 | 8 | use mock::Server; 9 | 10 | const RESP_200: &str = "HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n"; 11 | const RESP_400: &str = "HTTP/1.1 400 Bad Request\r\ncontent-length: 0\r\n\r\n"; 12 | 13 | #[test] 14 | fn test_client_empties() { 15 | smol::block_on(async { 16 | // expected req, and sends a resp200 17 | let testserver = Server::new_with_writes( 18 | "GET /foo/bar HTTP/1.1\r\ncontent-length: 0\r\nhost: example.org\r\n\r\n", 19 | RESP_200, 20 | 1, 21 | ); 22 | 23 | let mut req = Request::new(Body::empty()); 24 | // TODO make Host compile time error? 25 | req.headers_mut().insert(header::HOST, "example.org".parse().unwrap()); 26 | *req.uri_mut() = "/foo/bar".parse::().unwrap(); 27 | 28 | let resp = connect(testserver.clone(), req).await.unwrap(); 29 | 30 | testserver.assert(); 31 | assert_eq!(resp.status(), StatusCode::OK); 32 | }); 33 | } 34 | 35 | #[test] 36 | fn test_client_bad_request() { 37 | smol::block_on(async { 38 | // expected req, and sends a resp200 39 | let testserver = Server::new_with_writes( 40 | "GET /foo/bar HTTP/1.1\r\ncontent-length: 0\r\nhost: example.org\r\n\r\n", 41 | RESP_400, 42 | 1, 43 | ); 44 | 45 | let mut req = Request::new(Body::empty()); 46 | // TODO make Host compile time error? 47 | req.headers_mut().insert(header::HOST, "example.org".parse().unwrap()); 48 | *req.uri_mut() = "/foo/bar".parse::().unwrap(); 49 | 50 | let resp = connect(testserver.clone(), req).await.unwrap(); 51 | 52 | testserver.assert(); 53 | assert_eq!(resp.status(), StatusCode::BAD_REQUEST); 54 | }); 55 | } 56 | 57 | #[test] 58 | fn test_client_body_query() { 59 | smol::block_on(async { 60 | // expected req, and sends a resp200 61 | let testserver = Server::new_with_writes( 62 | "GET /foo/bar?one=two HTTP/1.1\r\ncontent-length: 6\r\nhost: example.org\r\n\r\ntophat", 63 | RESP_200, 64 | 1, 65 | ); 66 | 67 | let mut req = Request::new("tophat".into()); 68 | // TODO make Host compile time error? 
69 | req.headers_mut().insert(header::HOST, "example.org".parse().unwrap()); 70 | *req.uri_mut() = "/foo/bar?one=two".parse::().unwrap(); 71 | 72 | let resp = connect(testserver.clone(), req).await.unwrap(); 73 | 74 | testserver.assert(); 75 | assert_eq!(resp.status(), StatusCode::OK); 76 | }); 77 | } 78 | -------------------------------------------------------------------------------- /examples/cors.rs: -------------------------------------------------------------------------------- 1 | //! Cors example 2 | //! 3 | //! Careful, it's not completely automatic middleware, you currently have to use a switch statement 4 | //! to get the correct early-return behavior. 5 | 6 | use async_dup::Arc; 7 | use futures_util::io::{AsyncRead, AsyncWrite}; 8 | use http::Method; 9 | use smol::Async; 10 | use std::net::TcpListener; 11 | use tophat::{ 12 | server::{ 13 | accept, 14 | cors::Cors, 15 | glitch::Result, 16 | router::{Router, RouterRequestExt}, 17 | ResponseWriter, ResponseWritten, 18 | }, 19 | Request, 20 | }; 21 | 22 | fn main() -> std::result::Result<(), Box> { 23 | tracing_subscriber::fmt::init(); 24 | 25 | let cors = Cors::build() 26 | .allow_origin("http://example.com") 27 | .allow_methods(vec!["GET", "POST", "DELETE"]) 28 | .allow_header("content-type") 29 | .finish(); 30 | 31 | let router = Router::build() 32 | .data("Data from datastore") 33 | .at(Method::GET, "/:name", hello_user) 34 | .finish(); 35 | 36 | let listener = Async::::bind(([127,0,0,1],9999))?; 37 | 38 | smol::block_on(async { 39 | loop { 40 | let cors = cors.clone(); 41 | let router = router.clone(); 42 | 43 | let (stream, _) = listener.accept().await?; 44 | let stream = Arc::new(stream); 45 | 46 | let task = smol::spawn(async move { 47 | let serve = accept(stream, |req, mut resp_wtr| async { 48 | cors.validate(&req, &mut resp_wtr)?; 49 | 50 | // back to routing here 51 | let res = router.route(req, resp_wtr).await; 52 | res 53 | }) 54 | .await; 55 | 56 | if let Err(err) = serve { 57 | eprintln!("Error: {}", err); 58 | } 59 | }); 60 | 61 | task.detach(); 62 | } 63 | }) 64 | } 65 | 66 | async fn hello_user(req: Request, mut resp_wtr: ResponseWriter) -> Result 67 | where 68 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 69 | { 70 | //smol::Timer::after(std::time::Duration::from_secs(5)).await; 71 | 72 | let mut resp_body = format!("Hello, "); 73 | 74 | // add params to body string 75 | if let Some(params) = req.params() { 76 | for (k, v) in params { 77 | resp_body.push_str(&format!("{} = {}", k, v)); 78 | } 79 | } 80 | 81 | // add data to body string 82 | if let Some(data_string) = req.data::<&str>() { 83 | resp_body.push_str(&format!(" and {}", *data_string)); 84 | } 85 | 86 | resp_wtr.set_body(resp_body.into()); 87 | 88 | resp_wtr.send().await 89 | } 90 | -------------------------------------------------------------------------------- /examples/postgres/src/main.rs: -------------------------------------------------------------------------------- 1 | mod pool; 2 | 3 | use async_dup::Arc; 4 | use futures_lite::{AsyncRead, AsyncWrite}; 5 | use http::Method; 6 | use smol::Async; 7 | use std::env; 8 | use std::net::TcpListener; 9 | use tophat::{ 10 | server::{ 11 | accept, 12 | glitch, 13 | router::{Router, RouterRequestExt}, 14 | ResponseWriter, ResponseWritten, 15 | }, 16 | Request, 17 | }; 18 | 19 | use pool::{Pool, Manager}; 20 | 21 | fn main() -> Result<(), anyhow::Error> { 22 | dotenv::dotenv().ok(); 23 | tracing_subscriber::fmt::init(); 24 | 25 | // db setup 26 | let db_url = 
env::var("DATABASE_URL").expect("no db env var found"); 27 | let mgr = Manager::new(&db_url)?; 28 | let pool = Pool::new(mgr, 16); 29 | 30 | // router setup 31 | let router = Router::build() 32 | .data(pool) 33 | .at(Method::GET, "/", index) 34 | .at(Method::GET, "/planet/count", get_user_count_by_planet) 35 | .finish(); 36 | 37 | let listener = Async::::bind(([127,0,0,1],9999))?; 38 | 39 | smol::block_on(async { 40 | loop { 41 | let router = router.clone(); 42 | 43 | let (stream, _) = listener.accept().await?; 44 | let stream = Arc::new(stream); 45 | 46 | let task = smol::spawn(async move { 47 | let serve = accept(stream, |req, resp_wtr| async { 48 | let res = router.route(req, resp_wtr).await; 49 | res 50 | }) 51 | .await; 52 | 53 | if let Err(err) = serve { 54 | eprintln!("Error: {}", err); 55 | } 56 | }); 57 | 58 | task.detach(); 59 | } 60 | }) 61 | } 62 | 63 | async fn get_user_count_by_planet(req: Request, mut resp_wtr: ResponseWriter) -> glitch::Result 64 | where 65 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 66 | { 67 | let pool = req.data::().unwrap(); 68 | 69 | let client = pool.get().await?; 70 | 71 | let stmt = "SELECT planet, COUNT(*)::integer as count FROM test_users GROUP BY planet"; 72 | let rows = client.query(stmt, &[]).await?; 73 | 74 | let body = rows.iter() 75 | .map(|r| { 76 | let country: &str = r.get(0); 77 | let count: i32 = r.get(1); 78 | format!("{},{}\n", country, count) 79 | }).collect(); 80 | 81 | resp_wtr.set_text(body); 82 | 83 | resp_wtr.send().await 84 | } 85 | 86 | async fn index(_req: Request, mut resp_wtr: ResponseWriter) -> glitch::Result 87 | where 88 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 89 | { 90 | resp_wtr.set_text("still alive".into()); 91 | resp_wtr.send().await 92 | } 93 | -------------------------------------------------------------------------------- /src/client/error.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error as StdError; 2 | use std::fmt; 3 | 4 | type BoxError = Box; 5 | 6 | #[derive(Debug)] 7 | pub struct ClientError { 8 | kind: Kind, 9 | source: Option, 10 | } 11 | 12 | impl ClientError { 13 | pub(crate) fn new>(kind: Kind, err: Option) -> Self { 14 | Self { 15 | kind, 16 | source: err.map(Into::into), 17 | } 18 | } 19 | 20 | //// Returns the status code, if the error was generated from a response. 
21 | //pub fn status(&self) -> Option<StatusCode> { 22 | // match self.kind { 23 | // Kind::Status(code) => Some(code), 24 | // _ => None, 25 | // } 26 | //} 27 | } 28 | 29 | impl fmt::Display for ClientError { 30 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 31 | use Kind::*; 32 | match &self.kind { 33 | // TODO improve these messages 34 | Encode(msg) => { 35 | if let Some(ref err) = self.source { 36 | write!(f, "{:?}: {}", msg, err) 37 | } else { 38 | write!(f, "{:?}", msg) 39 | } 40 | } 41 | Decode(msg) => { 42 | if let Some(ref err) = self.source { 43 | write!(f, "{:?}: {}", msg, err) 44 | } else { 45 | write!(f, "{:?}", msg) 46 | } 47 | } 48 | Io => { 49 | if let Some(ref err) = self.source { 50 | write!(f, "Io Error: {}", err) 51 | } else { 52 | write!(f, "Io Error") 53 | } 54 | } 55 | } 56 | } 57 | } 58 | 59 | #[derive(Debug)] 60 | pub(crate) enum Kind { 61 | Encode(Option<String>), 62 | Decode(Option<String>), 63 | Io, 64 | //Status(StatusCode), 65 | } 66 | 67 | impl std::error::Error for ClientError { 68 | fn source(&self) -> Option<&(dyn StdError + 'static)> { 69 | self.source.as_ref().map(|e| &**e as _) 70 | } 71 | } 72 | 73 | pub(crate) fn encode<S: Into<Option<String>>>(msg: S) -> ClientError { 74 | ClientError::new(Kind::Encode(msg.into()), None::<BoxError>) 75 | } 76 | 77 | pub(crate) fn encode_io<E: Into<BoxError>>(err: E) -> ClientError { 78 | ClientError::new(Kind::Encode(None), Some(err)) 79 | } 80 | 81 | pub(crate) fn decode<S: Into<Option<String>>>(msg: S) -> ClientError { 82 | ClientError::new(Kind::Decode(msg.into()), None::<BoxError>) 83 | } 84 | 85 | pub(crate) fn decode_err<E: Into<BoxError>>(err: E) -> ClientError { 86 | ClientError::new(Kind::Decode(None), Some(err)) 87 | } 88 | 89 | pub(crate) fn io<E: Into<BoxError>>(err: E) -> ClientError { 90 | ClientError::new(Kind::Io, Some(err)) 91 | } 92 | -------------------------------------------------------------------------------- /examples/server_sent_events_2.rs: -------------------------------------------------------------------------------- 1 | use async_dup::Arc; 2 | use futures::Stream; 3 | use smol::Async; 4 | use std::net::TcpListener; 5 | use std::pin::Pin; 6 | use std::task::{Context, Poll}; 7 | use std::time::Duration; 8 | use tophat::server::accept; 9 | 10 | fn main() -> Result<(), Box<dyn std::error::Error>> { 11 | tracing_subscriber::fmt::init(); 12 | 13 | let listener = Async::<TcpListener>::bind(([127,0,0,1],9999))?; 14 | 15 | smol::block_on(async { 16 | loop { 17 | let (stream, _) = listener.accept().await?; 18 | let stream = Arc::new(stream); 19 | 20 | let task = smol::spawn(async move { 21 | let serve = accept(stream, |_req, mut resp_wtr| async { 22 | let (tx, rx) = async_channel::bounded(100); 23 | let client = Client(rx); 24 | resp_wtr.set_sse(client); 25 | 26 | // a one-shot to send the result of the resp_wtr, so that we can exit the 27 | // endpoint. 28 | let (tx_res, rx_res) = async_channel::bounded(1); 29 | 30 | smol::spawn(async move { 31 | let sse_res = resp_wtr.send().await; 32 | let _ = tx_res.send(sse_res).await; 33 | }) 34 | .detach(); 35 | 36 | let _ = tx.send("data: lorem\n\n".to_owned()).await; 37 | 38 | smol::Timer::after(Duration::from_secs(1)).await; 39 | 40 | let _ = tx.send("data: ipsum\n\n".to_owned()).await; 41 | 42 | // This rx will never receive because the stream will never close. 43 | // 44 | // If the exit from this endpoint was not dependent on the stream closing 45 | // (i.e. `ResponseWritten` could be constructed by user), then the exit of the 46 | // endpoint would drop the tx client, which would close the stream. However, I 47 | // don't think that is idiomatic behavior for an SSE; they should be 48 | // long-lived. 
49 | rx_res.recv().await.unwrap() 50 | }) 51 | .await; 52 | 53 | if let Err(err) = serve { 54 | eprintln!("Error: {}", err); 55 | } 56 | }); 57 | 58 | task.detach(); 59 | } 60 | }) 61 | } 62 | 63 | struct Client(async_channel::Receiver); 64 | 65 | impl Stream for Client { 66 | type Item = Result; 67 | 68 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 69 | match Pin::new(&mut self.0).poll_next(cx) { 70 | Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))), 71 | Poll::Ready(None) => Poll::Ready(None), 72 | Poll::Pending => Poll::Pending, 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | 8 | jobs: 9 | # rustfmt: 10 | # runs-on: ubuntu-latest 11 | # name: cargo fmt 12 | # steps: 13 | # - uses: actions/checkout@v2 14 | # 15 | # - name: install stable toolchain 16 | # uses: actions-rs/toolchain@v1 17 | # with: 18 | # toolchain: stable 19 | # profile: minimal 20 | # components: rustfmt 21 | # override: true 22 | # 23 | # - name: install rustfmt 24 | # run: rustup component add rustfmt 25 | # 26 | # - name: cargo fmt 27 | # uses: actions-rs/cargo@v1 28 | # with: 29 | # command: fmt 30 | # args: --all -- --check 31 | 32 | test-stable: 33 | runs-on: ${{ matrix.os }} 34 | strategy: 35 | matrix: 36 | os: [macOS-latest, windows-2019, ubuntu-latest] 37 | name: cargo clippy+test 38 | steps: 39 | - uses: actions/checkout@v2 40 | 41 | - name: install stable toolchain 42 | uses: actions-rs/toolchain@v1 43 | with: 44 | toolchain: stable 45 | components: clippy 46 | profile: minimal 47 | override: true 48 | 49 | - name: cargo test 50 | uses: actions-rs/cargo@v1 51 | with: 52 | command: test 53 | # for now, all features are additive 54 | args: --all-features 55 | 56 | test-nightly: 57 | runs-on: ${{ matrix.os }} 58 | strategy: 59 | matrix: 60 | os: [macOS-latest, windows-2019, ubuntu-latest] 61 | name: cargo clippy+test nightly 62 | steps: 63 | - uses: actions/checkout@v2 64 | 65 | - name: install stable toolchain 66 | uses: actions-rs/toolchain@v1 67 | with: 68 | toolchain: nightly 69 | components: clippy 70 | profile: minimal 71 | override: true 72 | 73 | - name: cargo clippy 74 | uses: actions-rs/cargo@v1 75 | with: 76 | command: clippy 77 | args: --all-features -- -D warnings 78 | 79 | - name: cargo test 80 | uses: actions-rs/cargo@v1 81 | with: 82 | command: test 83 | # for now, all features are additive 84 | args: --all-features 85 | 86 | # check-docs: 87 | # name: Docs 88 | # runs-on: ${{ matrix.os }} 89 | # strategy: 90 | # matrix: 91 | # os: [macOS-latest, windows-2019, ubuntu-latest] 92 | # steps: 93 | # - uses: actions/checkout@v2 94 | # 95 | # - name: install nightly toolchain 96 | # uses: actions-rs/toolchain@v1 97 | # with: 98 | # toolchain: nightly 99 | # profile: minimal 100 | # override: true 101 | # 102 | # - name: check docs 103 | # uses: actions-rs/cargo@v1 104 | # with: 105 | # command: doc 106 | # args: --document-private-items 107 | -------------------------------------------------------------------------------- /src/util.rs: -------------------------------------------------------------------------------- 1 | use futures_lite::{AsyncBufRead, AsyncRead, AsyncWrite}; 2 | use std::io; 3 | use std::pin::Pin; 4 | use std::task::{Context, Poll}; 5 | 6 | pub(crate) struct Cursor { 7 | inner: std::io::Cursor, 8 | } 9 | 10 | impl 
Cursor { 11 | pub(crate) fn new(t: T) -> Self { 12 | Self { 13 | inner: std::io::Cursor::new(t), 14 | } 15 | } 16 | 17 | #[allow(dead_code)] // used for testing 18 | pub(crate) fn into_inner(self) -> T { 19 | self.inner.into_inner() 20 | } 21 | } 22 | 23 | impl AsyncRead for Cursor 24 | where 25 | T: AsRef<[u8]> + Unpin, 26 | { 27 | fn poll_read( 28 | mut self: Pin<&mut Self>, 29 | _cx: &mut Context<'_>, 30 | buf: &mut [u8], 31 | ) -> Poll> { 32 | Poll::Ready(std::io::Read::read(&mut self.inner, buf)) 33 | } 34 | 35 | fn poll_read_vectored( 36 | mut self: Pin<&mut Self>, 37 | _cx: &mut Context<'_>, 38 | bufs: &mut [std::io::IoSliceMut<'_>], 39 | ) -> Poll> { 40 | Poll::Ready(std::io::Read::read_vectored(&mut self.inner, bufs)) 41 | } 42 | } 43 | 44 | impl AsyncBufRead for Cursor 45 | where 46 | T: AsRef<[u8]> + Unpin, 47 | { 48 | fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { 49 | Poll::Ready(std::io::BufRead::fill_buf(&mut self.get_mut().inner)) 50 | } 51 | 52 | fn consume(mut self: Pin<&mut Self>, amt: usize) { 53 | std::io::BufRead::consume(&mut self.inner, amt) 54 | } 55 | } 56 | 57 | pub(crate) struct Empty; 58 | 59 | pub(crate) fn empty() -> Empty { 60 | Empty 61 | } 62 | 63 | impl AsyncRead for Empty { 64 | fn poll_read( 65 | self: Pin<&mut Self>, 66 | _cx: &mut Context<'_>, 67 | _buf: &mut [u8], 68 | ) -> Poll> { 69 | Poll::Ready(Ok(0)) 70 | } 71 | 72 | fn poll_read_vectored( 73 | self: Pin<&mut Self>, 74 | _cx: &mut Context<'_>, 75 | _bufs: &mut [std::io::IoSliceMut<'_>], 76 | ) -> Poll> { 77 | Poll::Ready(Ok(0)) 78 | } 79 | } 80 | 81 | impl AsyncBufRead for Empty { 82 | fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { 83 | Poll::Ready(Ok(&[])) 84 | } 85 | 86 | fn consume(self: Pin<&mut Self>, _amt: usize) {} 87 | } 88 | 89 | impl AsyncWrite for Cursor> { 90 | fn poll_write( 91 | mut self: Pin<&mut Self>, 92 | _: &mut Context<'_>, 93 | buf: &[u8], 94 | ) -> Poll> { 95 | Poll::Ready(std::io::Write::write(&mut self.inner, buf)) 96 | } 97 | 98 | fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 99 | self.poll_flush(cx) 100 | } 101 | 102 | fn poll_flush(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { 103 | Poll::Ready(std::io::Write::flush(&mut self.inner)) 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /examples/errors_verbose.rs: -------------------------------------------------------------------------------- 1 | use async_dup::Arc; 2 | use futures_util::io::{AsyncRead, AsyncWrite}; 3 | use http::{Method, StatusCode}; 4 | use smol::Async; 5 | use std::net::TcpListener; 6 | use tophat::{ 7 | server::{ 8 | accept_with_opts, 9 | glitch::{Glitch, GlitchExt, Result}, 10 | router::Router, 11 | ResponseWriter, ResponseWritten, ServerOpts, 12 | }, 13 | Request, 14 | }; 15 | 16 | const S_500: StatusCode = StatusCode::INTERNAL_SERVER_ERROR; 17 | 18 | fn main() -> std::result::Result<(), Box> { 19 | tracing_subscriber::fmt::init(); 20 | 21 | let opts = ServerOpts { 22 | timeout: Some(std::time::Duration::from_secs(60)), 23 | verbose_glitch: true, 24 | }; 25 | 26 | let router = Router::build() 27 | .data("Data from datastore") 28 | .at(Method::GET, "/database_error", database_error) 29 | .at( 30 | Method::GET, 31 | "/database_error_context", 32 | database_error_context, 33 | ) 34 | .at(Method::GET, "/missing_data", missing_data) 35 | .finish(); 36 | 37 | let listener = Async::::bind(([127,0,0,1],9999))?; 38 | 39 | smol::block_on(async { 40 | loop { 41 | let 
router = router.clone(); 42 | let opts = opts.clone(); 43 | 44 | let (stream, _) = listener.accept().await?; 45 | let stream = Arc::new(stream); 46 | 47 | let task = smol::spawn(async move { 48 | let serve = accept_with_opts(stream, opts, |req, resp_wtr| async { 49 | let res = router.route(req, resp_wtr).await; 50 | res 51 | }) 52 | .await; 53 | 54 | if let Err(err) = serve { 55 | eprintln!("Error: {}", err); 56 | } 57 | }); 58 | 59 | task.detach(); 60 | } 61 | }) 62 | } 63 | 64 | async fn database_error(_req: Request, resp_wtr: ResponseWriter) -> Result 65 | where 66 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 67 | { 68 | use std::io; 69 | 70 | let failed_db: std::result::Result<(), _> = 71 | Err(io::Error::new(io::ErrorKind::Other, "The database crashed")); 72 | failed_db?; 73 | 74 | resp_wtr.send().await 75 | } 76 | 77 | async fn database_error_context( 78 | _req: Request, 79 | resp_wtr: ResponseWriter, 80 | ) -> Result 81 | where 82 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 83 | { 84 | use std::io; 85 | 86 | let failed_db: std::result::Result<(), _> = 87 | Err(io::Error::new(io::ErrorKind::Other, "The database crashed")); 88 | failed_db.glitch_ctx(S_500, "looking for user")?; 89 | 90 | resp_wtr.send().await 91 | } 92 | 93 | async fn missing_data(_req: Request, resp_wtr: ResponseWriter) -> Result 94 | where 95 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 96 | { 97 | let failed_db = None; 98 | 99 | // Manually create a 400 100 | // This will work even without anyhow integration. 101 | failed_db.ok_or_else(|| Glitch::bad_request())?; 102 | 103 | resp_wtr.send().await 104 | } 105 | -------------------------------------------------------------------------------- /examples/middleware.rs: -------------------------------------------------------------------------------- 1 | //! Example of "middleware" with simple cors. 2 | //! 3 | //! It's kind of a do-it-yourself middleware, there's not formal framework for it. It should be 4 | //! easy enough to plug in. 
5 | 6 | use async_dup::Arc; 7 | use futures_util::io::{AsyncRead, AsyncWrite}; 8 | use http::Method; 9 | use smol::Async; 10 | use std::net::TcpListener; 11 | use tophat::{ 12 | server::{ 13 | accept, 14 | glitch::Result, 15 | router::{Router, RouterRequestExt}, 16 | ResponseWriter, ResponseWritten, 17 | }, 18 | Request, 19 | }; 20 | 21 | fn main() -> std::result::Result<(), Box> { 22 | tracing_subscriber::fmt::init(); 23 | 24 | let cors = Arc::new(Cors { 25 | allow_origin: "*".to_owned(), 26 | }); 27 | 28 | let router = Router::build() 29 | .data("Data from datastore") 30 | .at(Method::GET, "/:name", hello_user) 31 | .finish(); 32 | 33 | let listener = Async::::bind(([127,0,0,1],9999))?; 34 | 35 | smol::block_on(async { 36 | loop { 37 | let cors = cors.clone(); 38 | let router = router.clone(); 39 | 40 | let (stream, _) = listener.accept().await?; 41 | let stream = Arc::new(stream); 42 | 43 | let task = smol::spawn(async move { 44 | let serve = accept(stream, |req, mut resp_wtr| async { 45 | // Do the middleware thing here 46 | // Cors preflight would require something like 47 | // ``` 48 | // if cors.preflight(&req, &mut resp_wtr) { 49 | // return resp_wtr.send(); 50 | // } 51 | // ``` 52 | cors.simple_cors(&req, &mut resp_wtr); 53 | 54 | // back to routing here 55 | let res = router.route(req, resp_wtr).await; 56 | res 57 | }) 58 | .await; 59 | 60 | if let Err(err) = serve { 61 | eprintln!("Error: {}", err); 62 | } 63 | }); 64 | 65 | task.detach(); 66 | } 67 | }) 68 | } 69 | 70 | async fn hello_user(req: Request, mut resp_wtr: ResponseWriter) -> Result 71 | where 72 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 73 | { 74 | //smol::Timer::after(std::time::Duration::from_secs(5)).await; 75 | 76 | let mut resp_body = format!("Hello, "); 77 | 78 | // add params to body string 79 | if let Some(params) = req.params() { 80 | for (k, v) in params { 81 | resp_body.push_str(&format!("{} = {}", k, v)); 82 | } 83 | } 84 | 85 | // add data to body string 86 | if let Some(data_string) = req.data::<&str>() { 87 | resp_body.push_str(&format!(" and {}", *data_string)); 88 | } 89 | 90 | resp_wtr.set_body(resp_body.into()); 91 | 92 | resp_wtr.send().await 93 | } 94 | 95 | struct Cors { 96 | allow_origin: String, 97 | } 98 | 99 | impl Cors { 100 | // Sets the Access Control Header on the Response of a Responsewriter, if Origin in Request is 101 | // set. 102 | // 103 | // No preflight. 104 | // 105 | // Unless the user changes the header in the endpoint, the header should be sent to the client. 
106 | fn simple_cors(&self, req: &Request, resp_wtr: &mut ResponseWriter) 107 | where 108 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 109 | { 110 | if req.headers().get("Origin").is_some() { 111 | resp_wtr.insert_header( 112 | "Access-Control-Allow-Origin", 113 | self.allow_origin.parse().unwrap(), 114 | ); 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /examples/identity.rs: -------------------------------------------------------------------------------- 1 | use async_dup::Arc; 2 | use futures_util::io::{AsyncRead, AsyncWrite}; 3 | use http::Method; 4 | use smol::Async; 5 | use std::net::TcpListener; 6 | use std::time::Duration; 7 | use tophat::{ 8 | server::{ 9 | accept, 10 | glitch::Result, 11 | identity::Identity, 12 | router::{Router, RouterRequestExt}, 13 | ResponseWriter, ResponseWritten, 14 | }, 15 | Request, 16 | }; 17 | 18 | fn main() -> std::result::Result<(), Box> { 19 | tracing_subscriber::fmt::init(); 20 | 21 | let identity = Identity::build("secret_server_key") 22 | .cookie_name("jwt") 23 | .cookie_secure(false) // necessary because example not https 24 | .issuer("tophat") 25 | .expiration_time(Duration::from_secs(30)) 26 | .finish(); 27 | 28 | let router = Router::build() 29 | .data(identity) 30 | .at(Method::GET, "/login/:user", login_user) 31 | .at(Method::GET, "/logout", logout_user) 32 | .at(Method::GET, "/", hello_user) 33 | .finish(); 34 | 35 | let listener = Async::::bind(([127,0,0,1],9999))?; 36 | 37 | smol::block_on(async { 38 | loop { 39 | let router = router.clone(); 40 | 41 | let (stream, _) = listener.accept().await?; 42 | let stream = Arc::new(stream); 43 | 44 | let task = smol::spawn(async move { 45 | let serve = accept(stream, |req, resp_wtr| async { 46 | let res = router.route(req, resp_wtr).await; 47 | res 48 | }) 49 | .await; 50 | 51 | if let Err(err) = serve { 52 | eprintln!("Error: {}", err); 53 | } 54 | }); 55 | 56 | task.detach(); 57 | } 58 | }) 59 | } 60 | 61 | async fn login_user(req: Request, mut resp_wtr: ResponseWriter) -> Result 62 | where 63 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 64 | { 65 | let identity = req.data::().unwrap(); 66 | let user = req.get_param("user").unwrap(); 67 | 68 | // Here, we'll just assume that user is valid. This will usually be a call to the db to check 69 | // against hashed password. 70 | 71 | // Since user is valid, we'll set a cookie with the jwt token 72 | identity.set_auth_token(user, &mut resp_wtr); 73 | 74 | println!("Login req headers{:?}", req.headers()); 75 | println!("Login res headers{:?}", resp_wtr.response().headers()); 76 | 77 | resp_wtr.send().await 78 | } 79 | 80 | async fn logout_user(req: Request, mut resp_wtr: ResponseWriter) -> Result 81 | where 82 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 83 | { 84 | // Since we're using jwt tokens, we don't need to do a check on some session store to remove 85 | // the session; just send the "forget" cookie. 
86 | 87 | let identity = req.data::().unwrap(); 88 | 89 | identity.forget(&mut resp_wtr); 90 | 91 | println!("Logout req headers{:?}", req.headers()); 92 | println!("Logout res headers{:?}", resp_wtr.response().headers()); 93 | 94 | resp_wtr.send().await 95 | } 96 | 97 | // Says hello to user based on user login name 98 | async fn hello_user(req: Request, mut resp_wtr: ResponseWriter) -> Result 99 | where 100 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 101 | { 102 | let identity = req.data::().unwrap(); 103 | 104 | println!("Hello req headers{:?}", req.headers()); 105 | 106 | let user = match identity.authorized_user(&req) { 107 | Some(u) => u, 108 | None => { 109 | resp_wtr.set_code(400); 110 | return resp_wtr.send().await; 111 | } 112 | }; 113 | 114 | resp_wtr.set_text(format!("Hello {}", user)); 115 | resp_wtr.send().await 116 | } 117 | -------------------------------------------------------------------------------- /NOTES.md: -------------------------------------------------------------------------------- 1 | # Implementation Notes 2 | Probably going to follow `async-h1` overall structure to keep it simple first. Just try to use the components from hyper, which have fewer deps. 3 | 4 | The inspiration: talking about using Sinks instead of Streams for sending, and how that can affect observability of responses. 5 | https://github.com/hyperium/hyper/issues/2181 6 | https://users.rust-lang.org/t/async-interviews/35167/33 7 | 8 | `hyper/http` and `hyper/http-body` and `http/httpparse` for implementing the http basics 9 | https://github.com/hyperium/http-body/blob/master/src/lib.rs 10 | 11 | Some other implementation examples from async-std: `http-types`, `async-h1` 12 | https://github.com/http-rs/http-types/blob/master/src/body.rs 13 | 14 | Other implementation notes from hyper: `body.rs` 15 | https://docs.rs/hyper/0.13.5/src/hyper/body/body.rs.html#84-87 16 | 17 | # Error handling 18 | Try to handle local fails in each module, bubbling up those failures so they can be handled in the root module. Try to keep the handlers for those failures in each modules also, and output a response, because internal failures should generally be handled by issuing a bad request or internal server error. Catastrophic system failure is a bug. Basically, `accept` should never fail. 19 | 20 | # HTTP RFCs to read 21 | [Message Syntax and Routing](https://tools.ietf.org/html/rfc7230) 22 | - [Message Body Length](https://tools.ietf.org/html/rfc7230#section-3.3.3) 23 | [Original](https://tools.ietf.org/html/rfc2616) 24 | 25 | # URI handling 26 | Looks like https://tools.ietf.org/html/rfc2616#section-5.2 is the section to look at. The question is whether hyper just ignores the host, like the section says is possible? 27 | 28 | Section 19.6.1.1 (requirements for HTTP/1.1 server): 29 | 30 | - server must report 400 if no Host header 31 | - server must accept absolute URI 32 | - https://tools.ietf.org/html/rfc2616#section-19.6.1.1 33 | 34 | absolute URI: https://tools.ietf.org/html/rfc2396#section-3 35 | 36 | absoluteURI is the "whole" url, absolute path is everything after the authority excluding query. 37 | 38 | Check what happens with query strings. 39 | 40 | Looks like hyper just ignores: https://github.com/hyperium/hyper/blob/master/src/proto/h1/role.rs#L102 41 | 42 | ```rust 43 | subject = RequestLine( 44 | Method::from_bytes(req.method.unwrap().as_bytes())?, 45 | req.path.unwrap().parse()?, 46 | ); 47 | ``` 48 | 49 | This lets them just accept absolute paths also. 
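For reference, a rough sketch of what reducing an absolute-form request-target to origin-form (just path plus query, with the authority left to the Host header) could look like; `origin_form` is only an illustrative helper, not tophat API:

```rust
use http::Uri;

// "http://example.org/users?id=1" -> "/users?id=1"
fn origin_form(target: &Uri) -> String {
    let mut s = target.path().to_owned();
    if let Some(query) = target.query() {
        s.push('?');
        s.push_str(query);
    }
    s
}
```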
50 | 51 | async-h1 formats the path with a scheme and authority; I think this is incorrect. 52 | 53 | # Philosophical notes: 54 | 55 | Designed around using language constructs to build your app, instead of creating another layer of abstraction. So use streams and `AsyncRead`/`AsyncWrite` instead of a service architecture when possible. The language already gives you tools which are very powerful and composable, so defer to those when possible. 56 | 57 | And instead of services for backend concerns (like timeout and compression), just use async IO traits and streams; tophat just needs to provide hooks for them. 58 | 59 | Also, the aim is not to make easy things appear easy, but to make hard things manageable. 60 | 61 | # Error handling 62 | While trying to implement verbose HTTP errors via anyhow for Glitch, I realized that using `anyhow::Context` results in an `anyhow::Error`, which cannot be converted to `Glitch` because it doesn't implement `std::error::Error` (for coherence reasons). 63 | 64 | Since I only want: 65 | - the ability to add context to errors 66 | - the ability to read the underlying errors 67 | 68 | That means I don't really need something like anyhow; it is powerful, not just because you can convert any error, but because it holds those errors' info. Since Glitch will never do any more processing of the errors it's converted from (it gets turned directly into a response), it's fine not to have the power of anyhow. 69 | 70 | The solution is to simply remove anyhow and convert all errors to Strings in Glitch. 71 | 72 | The context part is also pretty easy to add, separately. 73 | 74 | In the future, if users want to use anyhow or write their own framework/router, they can just have their own custom error that will convert to Glitch. (Perhaps something like `impl Response` in Actix?) 75 | 76 | ## More on error handling 77 | I like the idea that Rust is supposed to handle errors on-the-spot, because error handling is not lesser than the "happy path". This means that pushing all of your errors to be handled by a custom catch is not as good as trying to handle as much as possible on the spot. 78 | 79 | There are limits, which is why anyhow is a thing. 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tophat 2 | [![crates.io](https://meritbadge.herokuapp.com/tophat)](https://crates.io/crates/tophat) 3 | [![Released API docs](https://docs.rs/tophat/badge.svg)](https://docs.rs/tophat) 4 | [![CI](https://github.com/hwchen/tophat/workflows/ci/badge.svg)](https://github.com/hwchen/tophat/actions?query=workflow%3Aci) 5 | 6 | A small, pragmatic, and flexible async HTTP server library. Currently in beta. 7 | 8 | Cargo.toml: 9 | ``` 10 | tophat = "0.3.0" 11 | ``` 12 | 13 | The goal is to be low-level and small enough to work with different async runtimes and not dictate user architecture, while having enough convenience functions to still easily build a REST API. More library than framework.
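For a feel of the convenience layer, here is a rough sketch of routing with the optional `router` feature (the `hello_user` handler is a placeholder; `examples/middleware.rs` shows a complete program built this way):

```rust
use http::Method;
use tophat::server::router::Router;

// Build a router that shares some data with handlers and captures a path param.
let router = Router::build()
    .data("data from a datastore")            // handlers read this via req.data::<&str>()
    .at(Method::GET, "/hello/:name", hello_user)
    .finish();

// Inside the accept loop, routing a request is then just:
// router.route(req, resp_wtr).await
```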
14 | 15 | Also, this: 16 | ```rust 17 | async fn handler(_req: Request, resp_wtr: ResponseWriter) -> Result 18 | where W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 19 | { 20 | // Default `send` is 200 OK 21 | let done = resp_wtr.send().await?; 22 | // Do things here after resp is written, if you like 23 | Ok(done) 24 | } 25 | ``` 26 | 27 | instead of: 28 | ```rust 29 | async fn handler(req: Request) -> Result { 30 | Ok(Response::empty()) 31 | } 32 | 33 | ``` 34 | Don't be scared away by the generics and trait bounds! They won't bite! (probably) 35 | 36 | # Features 37 | - HTTP/1.1 38 | - Works with any TCP stream that implements `futures::{AsyncRead, AsyncWrite}`. 39 | - All dependencies are async-ecosystem independent. 40 | - Not meant to be a framework; minimal abstraction. 41 | - #![deny(unsafe_code)] 42 | - Fast enough. 43 | - Router `features = ["router"]`, very minimal. 44 | - Cors `features = ["cors"]`. 45 | - Identity `features = ["identity"]`. 46 | - "Middleware" capabilities by using functions in front of the router. 47 | - Convenient error/response handling using `Glitch` and `GlitchExt`, which chain onto both `Result` and `Option`. 48 | - Extensive examples. 49 | - A minimal client (not under active development). 50 | 51 | Correct handling of the HTTP protocol is a priority. 52 | 53 | # Example 54 | Using [`smol`](https://github.com/stjepang/smol) as the async runtime. The example is single-threaded; see the `smol` docs for how to make a multi-threaded executor. 55 | ```rust 56 | use smol::{Async, Task}; 57 | use std::net::TcpListener; 58 | use async_dup::Arc; 59 | use tophat::server::accept; 60 | 61 | fn main() -> Result<(), Box> { 62 | let listener = Async::::bind("127.0.0.1:9999")?; 63 | 64 | smol::block_on(async { 65 | loop { 66 | let (stream, _) = listener.accept().await?; 67 | let stream = Arc::new(stream); 68 | 69 | let task = smol::spawn(async move { 70 | let serve = accept(stream, |_req, resp_wtr| async { 71 | resp_wtr.send().await 72 | }).await; 73 | 74 | if let Err(err) = serve { 75 | eprintln!("Error: {}", err); 76 | } 77 | 78 | }); 79 | 80 | task.detach(); 81 | } 82 | }) 83 | } 84 | ``` 85 | 86 | # Philosophy 87 | 88 | I wouldn't consider this a batteries-included framework which tries to make every step easy. There are conveniences, but overall tophat is pretty minimal. For those who don't like boilerplate, another framework would probably work better. Users of tophat need to be familiar with async runtimes, setting up a TCP stream, `Arc`, traits, generics, etc. Tophat won't hold your hand. 89 | 90 | In exchange, tophat provides more transparency and more control. Tophat won't dictate how to structure your app; it should play nicely with your architecture. 91 | 92 | # Inspiration 93 | I was inspired to write tophat because: 94 | - I wanted to have an async https server which was not tied to an async ecosystem and 95 | - I saw this GitHub issue on using a `ResponseWriter` instead of returning `Response`: https://github.com/hyperium/hyper/issues/2181 96 | 97 | # Thanks 98 | Especially to [async-h1](https://github.com/http-rs/async-h1), whose eye for structure and design I appreciate, and whose code base tophat is built from.
101 | And to [hyper](https://github.com/hyperium/hyper), whose devotion to performance and correctness is inspiring, and whose basic http libraries tophat has incorporated. 102 | 103 | # License 104 | 105 | Licensed under either of 106 | 107 | * Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) 108 | * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) 109 | 110 | at your option. 111 | -------------------------------------------------------------------------------- /src/server/mod.rs: -------------------------------------------------------------------------------- 1 | #![deny(unsafe_code)] 2 | 3 | //! # tophat server 4 | 5 | #[cfg(feature = "cors")] 6 | pub mod cors; 7 | mod decode; 8 | mod encode; 9 | pub mod glitch; 10 | #[cfg(feature = "identity")] 11 | pub mod identity; 12 | mod response_writer; 13 | #[cfg(feature = "router")] 14 | pub mod router; 15 | pub mod error; 16 | 17 | use futures_lite::{AsyncRead, AsyncWrite, Future}; 18 | use std::time::Duration; 19 | 20 | use crate::body::Body; 21 | use crate::request::Request; 22 | use crate::response::Response; 23 | use crate::server::decode::DecodeFail; 24 | use crate::timeout::{timeout, TimeoutError}; 25 | 26 | use self::decode::decode; 27 | pub use self::error::ServerError; 28 | pub use self::glitch::{Glitch, Result}; 29 | use self::response_writer::InnerResponse; 30 | pub use self::response_writer::{ResponseWriter, ResponseWritten}; 31 | 32 | /// Accept a new incoming Http/1.1 connection 33 | /// 34 | /// Automatically supports KeepAlive 35 | pub async fn accept(io: RW, endpoint: F) -> std::result::Result<(), ServerError> 36 | where 37 | RW: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 38 | F: Fn(Request, ResponseWriter) -> Fut, 39 | Fut: Future>, 40 | { 41 | accept_with_opts(io, ServerOpts::default(), endpoint).await 42 | } 43 | 44 | /// Accept a new incoming Http/1.1 connection 45 | /// 46 | /// Automatically supports KeepAlive 47 | pub async fn accept_with_opts( 48 | io: RW, 49 | opts: ServerOpts, 50 | endpoint: F, 51 | ) -> std::result::Result<(), ServerError> 52 | where 53 | RW: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 54 | F: Fn(Request, ResponseWriter) -> Fut, 55 | Fut: Future>, 56 | { 57 | // All errors should be bubbled up to this fn to handle, either in logs or in responses. 58 | 59 | loop { 60 | // decode to Request 61 | // returns Ok(None) if no request to decode. So no need to exit on a ConnectionLost error. 62 | let req_fut = decode(io.clone()); 63 | 64 | // Handle req failure modes, timeout, eof 65 | let req = if let Some(timeout_duration) = opts.timeout { 66 | // this arm is for with timeout 67 | match timeout(timeout_duration, req_fut).await { 68 | Ok(Ok(Some(r))) => r, 69 | Ok(Ok(None)) | Err(TimeoutError { .. 
}) => { 70 | //debug!("Timeout Error"); 71 | break; // EOF or timeout 72 | } 73 | Ok(Err(err)) => { 74 | handle_decode_fail(err, io.clone()).await?; 75 | // and continue on to next request 76 | continue; 77 | } 78 | } 79 | } else { 80 | // This arm is for no timeout 81 | match req_fut.await { 82 | Ok(Some(r)) => r, 83 | Ok(None) => break, // EOF 84 | Err(err) => { 85 | handle_decode_fail(err, io.clone()).await?; 86 | // and continue on to next request 87 | continue; 88 | } 89 | } 90 | }; 91 | 92 | let resp_wtr = ResponseWriter { 93 | writer: io.clone(), 94 | response: Response::new(Body::empty()), 95 | }; 96 | if let Err(glitch) = endpoint(req, resp_wtr).await { 97 | let _ = glitch 98 | .into_inner_response(opts.verbose_glitch) 99 | .send(io.clone()) 100 | .await; 101 | } 102 | } 103 | 104 | Ok(()) 105 | } 106 | 107 | /// Options for the tophat server. 108 | #[derive(Clone)] 109 | pub struct ServerOpts { 110 | /// Connection timeout (in seconds) 111 | pub timeout: Option, 112 | /// Option to send error (from convertin error to Glitch) traces in an error response (Glitch) 113 | pub verbose_glitch: bool, 114 | } 115 | 116 | impl Default for ServerOpts { 117 | fn default() -> Self { 118 | Self { 119 | timeout: Some(Duration::from_secs(60)), 120 | verbose_glitch: false, 121 | } 122 | } 123 | } 124 | 125 | // handles both writing error response and bubbling up major system errors as necessary. 126 | async fn handle_decode_fail(fail: DecodeFail, io: RW) -> std::result::Result<(), ServerError> 127 | where 128 | RW: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 129 | { 130 | // send a resp for errors from decoding 131 | if let Some(err_resp) = decode::fail_to_response_and_log(&fail) { 132 | let _ = err_resp.send(io.clone()).await; 133 | } 134 | // Early return if there's a major error. 135 | if let Some(crate_err) = decode::fail_to_crate_err(fail) { 136 | return Err(crate_err); 137 | } 138 | 139 | Ok(()) 140 | } 141 | -------------------------------------------------------------------------------- /src/client/decode.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::nonminimal_bool)] 2 | #![allow(clippy::op_ref)] 3 | 4 | use futures_lite::{io::BufReader, AsyncBufReadExt, AsyncRead, AsyncReadExt}; 5 | use http::{ 6 | header::{HeaderMap, HeaderName, HeaderValue, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}, 7 | StatusCode, 8 | }; 9 | use httpdate::fmt_http_date; 10 | 11 | use super::error::{self, ClientError}; 12 | use crate::chunked::ChunkedDecoder; 13 | use crate::{Body, Response}; 14 | 15 | const CR: u8 = b'\r'; 16 | const LF: u8 = b'\n'; 17 | const MAX_HEADERS: usize = 128; 18 | const MAX_HEAD_LENGTH: usize = 8 * 1024; 19 | 20 | /// Decode an HTTP response on the client. 21 | #[doc(hidden)] 22 | pub async fn decode(reader: R) -> Result 23 | where 24 | R: AsyncRead + Unpin + Send + Sync + 'static, 25 | { 26 | let mut reader = BufReader::new(reader); 27 | let mut buf = Vec::new(); 28 | let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; 29 | let mut httparse_res = httparse::Response::new(&mut headers); 30 | 31 | // Keep reading bytes from the stream until we hit the end of the stream. 32 | loop { 33 | let bytes_read = reader 34 | .read_until(LF, &mut buf) 35 | .await 36 | .map_err(error::decode_err)?; 37 | // No more bytes are yielded from the stream. 38 | if bytes_read == 0 { 39 | return Err(error::decode("Empty response".to_owned())); 40 | } 41 | 42 | // Prevent CWE-400 DDOS with large HTTP Headers. 
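// (Reject the head outright once the buffered bytes reach MAX_HEAD_LENGTH, before handing anything to the parser.)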
43 | if !(buf.len() < MAX_HEAD_LENGTH) { 44 | return Err(error::decode( 45 | "Head byte length should be less than 8kb".to_owned(), 46 | )); 47 | }; 48 | 49 | // We've hit the end delimiter of the stream. 50 | let idx = buf.len() - 1; 51 | if idx >= 3 && &buf[idx - 3..=idx] == [CR, LF, CR, LF] { 52 | break; 53 | } 54 | if idx >= 1 && &buf[idx - 1..=idx] == [LF, LF] { 55 | break; 56 | } 57 | } 58 | 59 | // Convert our header buf into an httparse instance, and validate. 60 | let status = httparse_res.parse(&buf).map_err(error::decode_err)?; 61 | if status.is_partial() { 62 | return Err(error::decode("Malformed HTTP head".to_owned())); 63 | }; 64 | 65 | let code = httparse_res.code; 66 | let code = code.ok_or_else(|| error::decode("No status code found".to_owned()))?; 67 | 68 | // Convert httparse headers + body into a `http_types::Response` type. 69 | let version = httparse_res.version; 70 | let version = version.ok_or_else(|| error::decode("No version found".to_owned()))?; 71 | if version != 1 { 72 | return Err(error::decode("Unsupported HTTP version".to_owned())); 73 | }; 74 | 75 | let mut headers = HeaderMap::new(); 76 | for header in httparse_res.headers.iter() { 77 | let value = HeaderValue::from_bytes(header.value).map_err(error::decode_err)?; 78 | let name: HeaderName = header.name.parse().map_err(error::decode_err)?; 79 | headers.append(name, value); 80 | } 81 | 82 | if headers.get(DATE).is_none() { 83 | let date = fmt_http_date(std::time::SystemTime::now()); 84 | let value = HeaderValue::from_str(&date).map_err(error::decode_err)?; 85 | headers.insert(DATE, value); 86 | } 87 | 88 | let content_length = headers.get(CONTENT_LENGTH); 89 | let transfer_encoding = headers.get(TRANSFER_ENCODING); 90 | 91 | if !(content_length.is_none() || transfer_encoding.is_none()) { 92 | return Err(error::decode("Unexpected Content-Length header".to_owned())); 93 | }; 94 | 95 | // must be either transfer encoding or content length/ TODO compile time 96 | let mut res = Response::new(Body::empty()); 97 | 98 | if let Some(encoding) = headers.get(TRANSFER_ENCODING).iter().last() { 99 | if *encoding == "chunked" { 100 | let mut body = Body::empty(); 101 | let trailers_sender = body.send_trailers(); 102 | let reader = BufReader::new(ChunkedDecoder::new(reader, trailers_sender)); 103 | body.set_inner(reader, None); 104 | *res.body_mut() = body; 105 | 106 | // Return the response. 107 | return Ok(res); 108 | } 109 | } 110 | 111 | // Check for Content-Length. 112 | if let Some(len) = headers.get(CONTENT_LENGTH).iter().last() { 113 | let len = len 114 | .to_str() 115 | .map_err(error::decode_err)? 116 | .parse::() 117 | .map_err(error::decode_err)?; 118 | res = Response::new(Body::from_reader(reader.take(len as u64), Some(len))); 119 | } 120 | 121 | *res.status_mut() = StatusCode::from_u16(code).map_err(error::decode_err)?; 122 | 123 | *res.headers_mut() = headers; 124 | 125 | // Return the response. 126 | Ok(res) 127 | } 128 | -------------------------------------------------------------------------------- /src/client/encode.rs: -------------------------------------------------------------------------------- 1 | use futures_lite::{AsyncRead, AsyncWriteExt}; 2 | use http::{header::HOST, Method, Request}; 3 | use std::io; 4 | use std::pin::Pin; 5 | use std::task::{Context, Poll}; 6 | use tracing::trace; 7 | 8 | use super::{error, ClientError}; 9 | use crate::Body; 10 | 11 | /// An HTTP encoder. 12 | #[doc(hidden)] 13 | pub struct Encoder { 14 | /// Keep track how far we've indexed into the headers + body. 
15 | cursor: usize, 16 | /// HTTP headers to be sent. 17 | headers: Vec, 18 | /// Check whether we're done sending headers. 19 | headers_done: bool, 20 | /// HTTP body to be sent. 21 | body: Body, 22 | /// Check whether we're done with the body. 23 | body_done: bool, 24 | /// Keep track of how many bytes have been read from the body stream. 25 | body_bytes_read: usize, 26 | } 27 | 28 | impl Encoder { 29 | /// Encode an HTTP request on the client. 30 | pub async fn encode(req: Request) -> Result { 31 | let mut buf: Vec = vec![]; 32 | 33 | // clients are not supposed to send uri frags when retrieving a document 34 | // removed code for that here, skip to query. 35 | let mut url = req.uri().path().to_owned(); 36 | if let Some(query) = req.uri().query() { 37 | url.push('?'); 38 | url.push_str(query); 39 | } 40 | 41 | // A client sending a CONNECT request MUST consists of only the host 42 | // name and port number of the tunnel destination, separated by a colon. 43 | // See: https://tools.ietf.org/html/rfc7231#section-4.3.6 44 | if req.method() == Method::CONNECT { 45 | let host = req.uri().host(); 46 | let host = host.ok_or_else(|| error::encode("Missing hostname".to_owned()))?; 47 | let port = req.uri().port(); // or known default? 48 | let port = port.ok_or_else(|| error::encode("Missing port".to_owned()))?; 49 | url = format!("{}:{}", host, port); 50 | } 51 | 52 | let val = format!("{} {} HTTP/1.1\r\n", req.method(), url); 53 | trace!("> {}", &val); 54 | buf.write_all(val.as_bytes()) 55 | .await 56 | .map_err(error::encode_io)?; 57 | 58 | if req.headers().get(HOST).is_none() { 59 | // Insert Host header 60 | // Insert host 61 | let host = req.uri().host(); 62 | let host = host.ok_or_else(|| error::encode("Missing hostname".to_owned()))?; 63 | let val = if let Some(port) = req.uri().port() { 64 | format!("host: {}:{}\r\n", host, port) 65 | } else { 66 | format!("host: {}\r\n", host) 67 | }; 68 | 69 | trace!("> {}", &val); 70 | buf.write_all(val.as_bytes()) 71 | .await 72 | .map_err(error::encode_io)?; 73 | } 74 | 75 | // Insert Proxy-Connection header when method is CONNECT 76 | if req.method() == Method::CONNECT { 77 | let val = "proxy-connection: keep-alive\r\n".to_owned(); 78 | trace!("> {}", &val); 79 | buf.write_all(val.as_bytes()) 80 | .await 81 | .map_err(error::encode_io)?; 82 | } 83 | 84 | // If the body isn't streaming, we can set the content-length ahead of time. Else we need to 85 | // send all items in chunks. 
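// (A body with no known length would need chunked framing, which this client does not implement yet; see the panic below.)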
86 | if let Some(len) = req.body().length { 87 | let val = format!("content-length: {}\r\n", len); 88 | trace!("> {}", &val); 89 | buf.write_all(val.as_bytes()) 90 | .await 91 | .map_err(error::encode_io)?; 92 | } else { 93 | // write!(&mut buf, "Transfer-Encoding: chunked\r\n")?; 94 | panic!("chunked encoding is not implemented yet"); 95 | // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding 96 | // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Trailer 97 | } 98 | 99 | for (header, value) in req.headers().iter() { 100 | buf.write_all(header.as_str().as_bytes()) 101 | .await 102 | .map_err(error::encode_io)?; 103 | buf.write_all(b": ").await.map_err(error::encode_io)?; 104 | buf.write_all(value.as_bytes()) 105 | .await 106 | .map_err(error::encode_io)?; 107 | buf.write_all(b"\r\n").await.map_err(error::encode_io)?; 108 | } 109 | 110 | buf.write_all(b"\r\n").await.map_err(error::encode_io)?; 111 | 112 | Ok(Self { 113 | body: req.into_body(), 114 | headers: buf, 115 | cursor: 0, 116 | headers_done: false, 117 | body_done: false, 118 | body_bytes_read: 0, 119 | }) 120 | } 121 | } 122 | 123 | impl AsyncRead for Encoder { 124 | fn poll_read( 125 | mut self: Pin<&mut Self>, 126 | cx: &mut Context<'_>, 127 | buf: &mut [u8], 128 | ) -> Poll> { 129 | // Send the headers. As long as the headers aren't fully sent yet we 130 | // keep sending more of the headers. 131 | let mut bytes_read = 0; 132 | if !self.headers_done { 133 | let len = std::cmp::min(self.headers.len() - self.cursor, buf.len()); 134 | let range = self.cursor..self.cursor + len; 135 | buf[0..len].copy_from_slice(&self.headers[range]); 136 | self.cursor += len; 137 | if self.cursor == self.headers.len() { 138 | self.headers_done = true; 139 | } 140 | bytes_read += len; 141 | } 142 | 143 | if !self.body_done { 144 | let inner_poll_result = Pin::new(&mut self.body).poll_read(cx, &mut buf[bytes_read..]); 145 | let n = match inner_poll_result { 146 | Poll::Ready(Ok(n)) => n, 147 | Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), 148 | Poll::Pending => { 149 | if bytes_read == 0 { 150 | return Poll::Pending; 151 | } else { 152 | return Poll::Ready(Ok(bytes_read as usize)); 153 | } 154 | } 155 | }; 156 | bytes_read += n; 157 | self.body_bytes_read += n; 158 | if bytes_read == 0 { 159 | self.body_done = true; 160 | } 161 | } 162 | 163 | Poll::Ready(Ok(bytes_read as usize)) 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /src/body.rs: -------------------------------------------------------------------------------- 1 | use futures_lite::{AsyncBufRead, AsyncRead, AsyncReadExt}; 2 | use std::io; 3 | use std::pin::Pin; 4 | use std::task::{Context, Poll}; 5 | 6 | use crate::trailers::{Trailers, TrailersSender}; 7 | use crate::util::{empty, Cursor}; 8 | use self::error::BodyError; 9 | 10 | pin_project_lite::pin_project! { 11 | /// A streaming body for use with requests and responses. 
12 | /// 13 | /// includes many convenience methods for converting to and from body 14 | pub struct Body { 15 | #[pin] 16 | pub(crate) reader: Box, 17 | pub(crate) length: Option, 18 | trailer_sender: Option>>, 19 | trailer_receiver: async_channel::Receiver>, 20 | } 21 | } 22 | 23 | impl Body { 24 | /// Create an empty Body 25 | pub fn empty() -> Self { 26 | let (sender, receiver) = async_channel::bounded(1); 27 | 28 | Self { 29 | reader: Box::new(empty()), 30 | length: Some(0), 31 | trailer_sender: Some(sender), 32 | trailer_receiver: receiver, 33 | } 34 | } 35 | 36 | /// Create a Body from a typ implementing AsyncRead 37 | /// 38 | /// if len: None will result in Transfer-Encoding: chunked 39 | /// if len: Some(n) will result in fixed body 40 | pub fn from_reader( 41 | reader: impl AsyncBufRead + Unpin + Send + Sync + 'static, 42 | len: Option, 43 | ) -> Self { 44 | let (sender, receiver) = async_channel::bounded(1); 45 | 46 | Self { 47 | reader: Box::new(reader), 48 | length: len, 49 | trailer_sender: Some(sender), 50 | trailer_receiver: receiver, 51 | } 52 | } 53 | 54 | /// Create a Body from bytes 55 | pub fn from_bytes(bytes: Vec) -> Self { 56 | let (sender, receiver) = async_channel::bounded(1); 57 | 58 | Self { 59 | length: Some(bytes.len()), 60 | reader: Box::new(Cursor::new(bytes)), 61 | trailer_sender: Some(sender), 62 | trailer_receiver: receiver, 63 | } 64 | } 65 | 66 | /// Read a Body into bytes. Consumes Body. 67 | pub async fn into_bytes(mut self) -> Result, BodyError> { 68 | let mut buf = Vec::with_capacity(1024); 69 | self.read_to_end(&mut buf) 70 | .await 71 | .map_err(BodyError::Conversion)?; 72 | Ok(buf) 73 | } 74 | 75 | /// Read a Body into a String. Consumes Body. 76 | pub async fn into_string(mut self) -> Result { 77 | let mut buf = String::with_capacity(self.length.unwrap_or(0)); 78 | self.read_to_string(&mut buf) 79 | .await 80 | .map_err(BodyError::Conversion)?; 81 | Ok(buf) 82 | } 83 | 84 | /// sending trailers not yet supported 85 | pub async fn into_bytes_with_trailer( 86 | mut self, 87 | ) -> Result<(Vec, Option>), BodyError> { 88 | let mut buf = Vec::with_capacity(1024); 89 | self.read_to_end(&mut buf) 90 | .await 91 | .map_err(BodyError::Conversion)?; 92 | let trailer = self.recv_trailers().await; 93 | Ok((buf, trailer)) 94 | } 95 | 96 | /// sending trailers not yet supported 97 | pub async fn into_string_with_trailer( 98 | mut self, 99 | ) -> Result<(String, Option>), BodyError> { 100 | let mut buf = String::with_capacity(self.length.unwrap_or(0)); 101 | self.read_to_string(&mut buf) 102 | .await 103 | .map_err(BodyError::Conversion)?; 104 | let trailer = self.recv_trailers().await; 105 | Ok((buf, trailer)) 106 | } 107 | 108 | /// sending trailers not yet supported 109 | pub fn send_trailers(&mut self) -> TrailersSender { 110 | let sender = self 111 | .trailer_sender 112 | .take() 113 | .expect("Trailers sender can only be constructed once"); 114 | TrailersSender::new(sender) 115 | } 116 | 117 | /// Don't use this directly if you also want to read the body. 
118 | /// In that case, prefer `into_{bytes, string}_with_trailer() 119 | pub async fn recv_trailers(&self) -> Option> { 120 | self.trailer_receiver.recv().await.ok() 121 | } 122 | 123 | pub(crate) fn set_inner( 124 | &mut self, 125 | rdr: impl AsyncBufRead + Unpin + Send + Sync + 'static, 126 | len: Option, 127 | ) { 128 | self.reader = Box::new(rdr); 129 | self.length = len; 130 | } 131 | } 132 | 133 | impl From for Body { 134 | fn from(s: String) -> Self { 135 | let (sender, receiver) = async_channel::bounded(1); 136 | 137 | Self { 138 | length: Some(s.len()), 139 | reader: Box::new(Cursor::new(s.into_bytes())), 140 | trailer_sender: Some(sender), 141 | trailer_receiver: receiver, 142 | } 143 | } 144 | } 145 | 146 | impl<'a> From<&'a str> for Body { 147 | fn from(s: &'a str) -> Self { 148 | let (sender, receiver) = async_channel::bounded(1); 149 | 150 | Self { 151 | length: Some(s.len()), 152 | reader: Box::new(Cursor::new(s.to_owned().into_bytes())), 153 | trailer_sender: Some(sender), 154 | trailer_receiver: receiver, 155 | } 156 | } 157 | } 158 | 159 | impl AsyncRead for Body { 160 | fn poll_read( 161 | mut self: Pin<&mut Self>, 162 | cx: &mut Context<'_>, 163 | buf: &mut [u8], 164 | ) -> Poll> { 165 | Pin::new(&mut self.reader).poll_read(cx, buf) 166 | } 167 | } 168 | 169 | impl AsyncBufRead for Body { 170 | fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 171 | let this = self.project(); 172 | this.reader.poll_fill_buf(cx) 173 | } 174 | 175 | fn consume(mut self: Pin<&mut Self>, amt: usize) { 176 | Pin::new(&mut self.reader).consume(amt) 177 | } 178 | } 179 | 180 | pub mod error { 181 | use std::fmt; 182 | 183 | /// Error for Body Type 184 | #[derive(Debug)] 185 | pub enum BodyError { 186 | /// Error when converting from a type to Body 187 | Conversion(std::io::Error), 188 | /// Error for sending or receiving trailer 189 | Trailer(std::io::Error), 190 | } 191 | 192 | impl std::error::Error for BodyError { 193 | fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 194 | use BodyError::*; 195 | match self { 196 | Conversion(err) => Some(err), 197 | Trailer(err) => Some(err), 198 | } 199 | } 200 | } 201 | 202 | impl fmt::Display for BodyError { 203 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 204 | use BodyError::*; 205 | match self { 206 | Conversion(err) => write!(f, "Error converting body: {}", err), 207 | Trailer(err) => write!(f, "Error in body trailer: {}", err), 208 | } 209 | } 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /tests/server_error_handling.rs: -------------------------------------------------------------------------------- 1 | mod mock; 2 | 3 | use tophat::{ 4 | glitch, glitch_code, 5 | http::StatusCode, 6 | server::{ 7 | accept, 8 | glitch::{Glitch, GlitchExt}, 9 | }, 10 | }; 11 | 12 | use mock::Client; 13 | 14 | const RESP_400: &str = "HTTP/1.1 400 Bad Request\r\ncontent-length: 0\r\n\r\n"; 15 | const RESP_500: &str = "HTTP/1.1 500 Internal Server Error\r\ncontent-length: 0\r\n\r\n"; 16 | const S_400: StatusCode = StatusCode::BAD_REQUEST; 17 | const S_500: StatusCode = StatusCode::INTERNAL_SERVER_ERROR; 18 | 19 | #[test] 20 | fn test_request_manually_create_glitch() { 21 | smol::block_on(async { 22 | let testclient = Client::new( 23 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\n\r\n", 24 | RESP_400, 25 | ); 26 | 27 | accept(testclient.clone(), |_req, resp_wtr| async move { 28 | "one".parse::().map_err(|_| Glitch::bad_request())?; 29 | let done = 
resp_wtr.send().await.unwrap(); 30 | 31 | Ok(done) 32 | }) 33 | .await 34 | .unwrap(); 35 | 36 | testclient.assert(); 37 | }); 38 | } 39 | 40 | #[test] 41 | fn test_request_glitch_with_context() { 42 | // one test to see that just `?` works, and another to see that manual Glitch creation still 43 | // works even with anyhow enabled. 44 | 45 | // automatic 46 | smol::block_on(async { 47 | let testclient = Client::new( 48 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\n\r\n", 49 | RESP_500, 50 | ); 51 | 52 | accept(testclient.clone(), |_req, resp_wtr| async move { 53 | "one".parse::()?; 54 | let done = resp_wtr.send().await.unwrap(); 55 | 56 | Ok(done) 57 | }) 58 | .await 59 | .unwrap(); 60 | 61 | testclient.assert(); 62 | }); 63 | 64 | // context no message 65 | smol::block_on(async { 66 | let testclient = Client::new( 67 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\n\r\n", 68 | RESP_500, 69 | ); 70 | 71 | accept(testclient.clone(), |_req, resp_wtr| async move { 72 | "one".parse::().glitch(S_500)?; 73 | let done = resp_wtr.send().await.unwrap(); 74 | 75 | Ok(done) 76 | }) 77 | .await 78 | .unwrap(); 79 | 80 | testclient.assert(); 81 | }); 82 | 83 | // context 84 | smol::block_on(async { 85 | let testclient = Client::new( 86 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\n\r\n", 87 | "HTTP/1.1 500 Internal Server Error\r\ncontent-length: 12\r\ncontent-type: text/plain\r\n\r\ncustom error", 88 | ); 89 | 90 | accept(testclient.clone(), |_req, resp_wtr| async move { 91 | "one".parse::().glitch_ctx(S_500, "custom error")?; 92 | let done = resp_wtr.send().await.unwrap(); 93 | 94 | Ok(done) 95 | }) 96 | .await 97 | .unwrap(); 98 | 99 | testclient.assert(); 100 | }); 101 | 102 | // context on Option 103 | smol::block_on(async { 104 | let testclient = Client::new( 105 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\n\r\n", 106 | "HTTP/1.1 400 Bad Request\r\ncontent-length: 12\r\ncontent-type: text/plain\r\n\r\ncustom error", 107 | ); 108 | 109 | accept(testclient.clone(), |_req, resp_wtr| async move { 110 | let usr = None; 111 | usr.glitch_ctx(S_400, "custom error")?; 112 | let done = resp_wtr.send().await.unwrap(); 113 | 114 | Ok(done) 115 | }) 116 | .await 117 | .unwrap(); 118 | 119 | testclient.assert(); 120 | }); 121 | 122 | // manual 123 | smol::block_on(async { 124 | let testclient = Client::new( 125 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\n\r\n", 126 | RESP_400, 127 | ); 128 | 129 | accept(testclient.clone(), |_req, resp_wtr| async move { 130 | "one".parse::().map_err(|_| Glitch::bad_request())?; 131 | let done = resp_wtr.send().await.unwrap(); 132 | 133 | Ok(done) 134 | }) 135 | .await 136 | .unwrap(); 137 | 138 | testclient.assert(); 139 | }); 140 | } 141 | 142 | #[test] 143 | fn test_request_glitch_macro() { 144 | smol::block_on(async { 145 | let testclient = Client::new( 146 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\n\r\n", 147 | RESP_500, 148 | ); 149 | 150 | accept(testclient.clone(), |_req, resp_wtr| async move { 151 | "one".parse::().map_err(|_| glitch!())?; 152 | let done = resp_wtr.send().await.unwrap(); 153 | 154 | Ok(done) 155 | }) 156 | .await 157 | .unwrap(); 158 | 159 | testclient.assert(); 160 | }); 161 | smol::block_on(async { 162 | let testclient = Client::new( 163 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\n\r\n", 164 | RESP_400, 165 | ); 166 | 167 | accept(testclient.clone(), |_req, resp_wtr| async move { 168 | "one" 169 | .parse::() 170 | .map_err(|_| glitch!(StatusCode::BAD_REQUEST))?; 171 | let done = resp_wtr.send().await.unwrap(); 172 | 173 | 
Ok(done) 174 | }) 175 | .await 176 | .unwrap(); 177 | 178 | testclient.assert(); 179 | }); 180 | smol::block_on(async { 181 | let testclient = Client::new( 182 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\n\r\n", 183 | "HTTP/1.1 400 Bad Request\r\ncontent-length: 12\r\ncontent-type: text/plain\r\n\r\ncustom error", 184 | ); 185 | 186 | accept(testclient.clone(), |_req, resp_wtr| async move { 187 | "one" 188 | .parse::() 189 | .map_err(|_| glitch!(StatusCode::BAD_REQUEST, "custom error"))?; 190 | let done = resp_wtr.send().await.unwrap(); 191 | 192 | Ok(done) 193 | }) 194 | .await 195 | .unwrap(); 196 | 197 | testclient.assert(); 198 | }); 199 | } 200 | 201 | #[test] 202 | fn test_request_glitch_code_macro() { 203 | // this one can panic if code incorrect 204 | smol::block_on(async { 205 | let testclient = Client::new( 206 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\n\r\n", 207 | "HTTP/1.1 400 Bad Request\r\ncontent-length: 12\r\ncontent-type: text/plain\r\n\r\ncustom error", 208 | ); 209 | 210 | accept(testclient.clone(), |_req, resp_wtr| async move { 211 | "one" 212 | .parse::() 213 | .map_err(|_| glitch_code!(400, "custom error"))?; 214 | let done = resp_wtr.send().await.unwrap(); 215 | 216 | Ok(done) 217 | }) 218 | .await 219 | .unwrap(); 220 | 221 | testclient.assert(); 222 | }); 223 | } 224 | 225 | #[test] 226 | #[should_panic] 227 | fn test_request_glitch_code_macro_panic() { 228 | // this one can panic if code incorrect 229 | smol::block_on(async { 230 | let testclient = Client::new( 231 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\n\r\n", 232 | RESP_400, 233 | ); 234 | 235 | accept(testclient.clone(), |_req, resp_wtr| async move { 236 | "one" 237 | .parse::() 238 | .map_err(|_| glitch_code!(1, "custom error"))?; 239 | let done = resp_wtr.send().await.unwrap(); 240 | 241 | Ok(done) 242 | }) 243 | .await 244 | .unwrap(); 245 | 246 | testclient.assert(); 247 | }); 248 | } 249 | -------------------------------------------------------------------------------- /src/server/router.rs: -------------------------------------------------------------------------------- 1 | // Thanks to tide, looking at Endpoint helped me understand how to coerce Fn to be storable. And to 2 | // reset-router, to understand how to use extensions and a trait to allow access from a Request. 3 | // 4 | // I'm a little nervous, because this router is so much more direct than others I've seen. I think 5 | // it's because I'm not integrating with a Service, and I'm not making middleware. 6 | // 7 | // My main concern was whether there would be such as thing as contention for the function. 8 | // - looks like there's no issue when there's a sleep timer in the fn, the number of connections 9 | // waiting on the sleep still scales with the number of total connections made. 10 | // - Oh, I should test with one autocannon set to a sleep endpoint and another to not. 11 | // - I'm able to run autocannon on `hello_` and simultaneously curl with another user, and it 12 | // goes through just fine. So there's no contention, it's just the sleep. 13 | // - So basically, I think that my code should be fine, the fn's code is just a reference, and 14 | // then runtime-code gets filled in (the appropriate params and stuff). 15 | 16 | //! Very Basic router 17 | //! 18 | //! - basic routing (no nesting) 19 | //! - holds global data 20 | //! 
- no extractors (you've got to find all the stuff you want attached to the `Request`) 21 | 22 | use crate::server::{Request, ResponseWriter, ResponseWritten, Result}; 23 | use async_dup::Arc; 24 | use futures_util::io::{AsyncRead, AsyncWrite}; 25 | use http::{Method, StatusCode}; 26 | use path_tree::PathTree; 27 | use std::future::Future; 28 | use std::pin::Pin; 29 | 30 | /// Convenience type for params. 31 | /// 32 | /// A `Vec` of (param_name, captured_value) 33 | pub type Params = Vec<(String, String)>; 34 | 35 | /// A minimal router 36 | #[derive(Clone)] 37 | pub struct Router 38 | where 39 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 40 | { 41 | tree: Arc>>>, 42 | data: Arc>, 43 | } 44 | 45 | impl Router 46 | where 47 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 48 | { 49 | /// Build a router 50 | pub fn build() -> RouterBuilder { 51 | RouterBuilder::new() 52 | } 53 | 54 | /// Call this to route a request 55 | pub async fn route( 56 | &self, 57 | mut req: Request, 58 | mut resp_wtr: ResponseWriter, 59 | ) -> Result { 60 | let path = "/".to_owned() + req.method().as_str() + req.uri().path(); 61 | 62 | match self.tree.find(&path) { 63 | Some((endpoint, params)) => { 64 | let params: Vec<(String, String)> = params 65 | .into_iter() 66 | .map(|(a, b)| (a.to_owned(), b.to_owned())) 67 | .collect(); 68 | 69 | // a place to store data and params 70 | // extensions is a type map, and then 71 | // data is also a type map. 72 | let extensions_mut = req.extensions_mut(); 73 | if let Some(ref data) = *self.data { 74 | extensions_mut.insert(data.clone()); 75 | } 76 | extensions_mut.insert(params); 77 | 78 | endpoint.call(req, resp_wtr).await 79 | } 80 | None => { 81 | resp_wtr.set_status(StatusCode::NOT_FOUND); 82 | resp_wtr.send().await 83 | } 84 | } 85 | } 86 | } 87 | 88 | /// Build a router 89 | pub struct RouterBuilder 90 | where 91 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 92 | { 93 | tree: PathTree>>, 94 | data: Option, 95 | } 96 | 97 | impl RouterBuilder 98 | where 99 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 100 | { 101 | fn new() -> Self { 102 | Self { 103 | tree: PathTree::new(), 104 | data: None, 105 | } 106 | } 107 | 108 | /// Attach a route with: method, path, endpoint. 109 | /// 110 | /// For the path, you can use: 111 | /// - Named parameters. e.g. :name. 112 | /// - Catch-All parameters. e.g. *any, it must always be at the end of the pattern. 113 | /// - Supports multiple naming for the same path segment. e.g. /users/:id and /users/:user_id/repos. 114 | /// - Don't care about routes orders, recursive lookup, Static -> Named -> Catch-All. 115 | /// (path-tree is used as the underlying router) 116 | pub fn at(self, method: Method, path: &str, endpoint: impl Endpoint) -> Self { 117 | let mut this = self; 118 | 119 | let path = "/".to_owned() + method.as_str() + path; 120 | 121 | this.tree.insert(&path, Box::new(endpoint)); 122 | this 123 | } 124 | 125 | /// Add data of type `T` to the router, to be accessed later through the request as 126 | /// `req.data()`. Data is stored in a typemap. 127 | /// 128 | /// Requires `RouterRequestExt`. 129 | pub fn data(self, data: T) -> Self { 130 | self.wrapped_data(Data::new(data)) 131 | } 132 | 133 | /// Add data of type `Data` to the router, to be accessed later through the request as 134 | /// `req.data()`. Data is stored in a typemap. 135 | /// 136 | /// Requires `RouterRequestExt`. 
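/// A rough usage sketch (`AppState` and the `index` handler are placeholders, not part of tophat):
///
/// ```ignore
/// let state = Data::new(AppState::default());
///
/// let router = Router::build()
///     .wrapped_data(state.clone())   // stored as `Data<AppState>`
///     .at(Method::GET, "/", index)
///     .finish();
///
/// // later, inside `index`: let state = req.data::<AppState>().unwrap();
/// ```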
137 | pub fn wrapped_data(mut self, data: T) -> Self { 138 | let mut map = self 139 | .data 140 | .take() 141 | .unwrap_or_else(type_map::concurrent::TypeMap::new); 142 | map.insert(data); 143 | self.data = Some(map); 144 | self 145 | } 146 | 147 | /// Finish building router 148 | pub fn finish(self) -> Router { 149 | Router { 150 | tree: Arc::new(self.tree), 151 | data: Arc::new(self.data.map(Data::new).map(DataMap)), 152 | } 153 | } 154 | } 155 | 156 | /// A trait for all endpoints, so that the user can just use any suitable closure or fn in the 157 | /// method for building a router. 158 | pub trait Endpoint: Send + Sync + 'static 159 | where 160 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 161 | { 162 | /// Invoke the endpoint within the given context 163 | fn call( 164 | &self, 165 | req: Request, 166 | resp_wtr: ResponseWriter, 167 | ) -> BoxFuture>; 168 | } 169 | 170 | impl Endpoint for F 171 | where 172 | F: Fn(Request, ResponseWriter) -> Fut, 173 | Fut: Future> + Send + 'static, 174 | Res: Into, 175 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 176 | { 177 | fn call( 178 | &self, 179 | req: Request, 180 | resp: ResponseWriter, 181 | ) -> BoxFuture> { 182 | let fut = (self)(req, resp); 183 | Box::pin(async move { 184 | let res = fut.await?; 185 | Ok(res.into()) 186 | }) 187 | } 188 | } 189 | 190 | // Router extras: Data and params access 191 | 192 | /// Data type for wrapping data for access within an endpoint 193 | pub struct Data(Arc); 194 | 195 | impl Data { 196 | /// Make a Data 197 | pub fn new(t: T) -> Self { 198 | Data(Arc::new(t)) 199 | } 200 | 201 | /// Make a Data from data which is wrapped in an Arc 202 | pub fn from_arc(arc: Arc) -> Self { 203 | Data(arc) 204 | } 205 | } 206 | 207 | impl std::ops::Deref for Data { 208 | type Target = T; 209 | 210 | fn deref(&self) -> &Self::Target { 211 | &*self.0 212 | } 213 | } 214 | 215 | impl Clone for Data { 216 | fn clone(&self) -> Self { 217 | Data(Arc::clone(&self.0)) 218 | } 219 | } 220 | 221 | #[derive(Clone)] 222 | struct DataMap(Data); 223 | 224 | /// Trait for convenience methods on a Request, which will allow for retrieving Data and params. 225 | pub trait RouterRequestExt { 226 | /// Get data 227 | fn data(&self) -> Option>; 228 | /// Get params 229 | fn params(&self) -> Option<&Params>; 230 | /// Get a specific param 231 | fn get_param(&self, key: &str) -> Option<&str>; 232 | } 233 | 234 | impl RouterRequestExt for crate::Request { 235 | fn data(&self) -> Option> { 236 | self.extensions() 237 | .get::() 238 | .and_then(|x| x.0.get::>()) 239 | .cloned() 240 | } 241 | 242 | fn params(&self) -> Option<&Params> { 243 | self.extensions().get::() 244 | } 245 | 246 | fn get_param(&self, key: &str) -> Option<&str> { 247 | if let Some(params) = self.extensions().get::() { 248 | for (k, v) in params { 249 | // for right now, just returns first. Is this ok? 250 | if key == k { 251 | return Some(v); 252 | } 253 | } 254 | } 255 | None 256 | } 257 | } 258 | 259 | pub(crate) type BoxFuture<'a, T> = Pin + Send + 'a>>; 260 | -------------------------------------------------------------------------------- /src/server/encode.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::borrow_interior_mutable_const)] 2 | // TODO support more than fixed length body 3 | // 4 | // Note: I fixed the encoding ranges on the buffer, and used bytes_read correctly. 5 | // But the final buffer ended up the same? 
I guess that sending the wrong number of bytes read 6 | // must have mucked up what the stream was reading back out. 7 | 8 | use futures_lite::AsyncRead; 9 | use http::header; 10 | use httpdate::fmt_http_date; 11 | use std::pin::Pin; 12 | use std::task::{Context, Poll}; 13 | use tracing::trace; 14 | 15 | use crate::chunked::ChunkedEncoder; 16 | 17 | use super::response_writer::InnerResponse; 18 | 19 | pub(crate) struct Encoder { 20 | resp: InnerResponse, 21 | state: EncoderState, 22 | 23 | // Tracks bytes read across one Encoder poll_read, which may span 24 | // several calls of encoding methods 25 | bytes_read: usize, 26 | 27 | head_buf: Vec, 28 | head_bytes_read: usize, 29 | 30 | content_length: Option, 31 | body_bytes_read: usize, 32 | 33 | chunked: ChunkedEncoder, 34 | } 35 | 36 | impl Encoder { 37 | pub(crate) fn encode(resp: InnerResponse) -> Self { 38 | let content_length = resp.body.length; 39 | 40 | Self { 41 | resp, 42 | state: EncoderState::Start, 43 | bytes_read: 0, 44 | head_buf: Vec::new(), 45 | head_bytes_read: 0, 46 | content_length, 47 | body_bytes_read: 0, 48 | chunked: ChunkedEncoder::new(), 49 | } 50 | } 51 | 52 | /// At start, prep headers for writing 53 | fn start(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { 54 | let version = self.resp.version; 55 | let status = self.resp.status; 56 | let date = if !self.resp.headers.contains_key(header::DATE) { 57 | Some(fmt_http_date(std::time::SystemTime::now())) 58 | } else { 59 | None 60 | }; 61 | // if there's a body, and if no content-type header, set as application/octet-stream as default 62 | #[allow(clippy::collapsible_if)] 63 | if self.content_length.is_none() || matches!(self.content_length, Some(x) if x > 0) { 64 | if !self.resp.headers.contains_key(header::CONTENT_TYPE) { 65 | self.resp.headers.insert(header::CONTENT_TYPE, "application/octet-stream".parse().unwrap()); 66 | } 67 | } 68 | let headers = self 69 | .resp 70 | .headers 71 | .iter() 72 | .filter(|(h, _)| **h != header::CONTENT_LENGTH) 73 | .filter(|(h, _)| **h != header::TRANSFER_ENCODING); 74 | 75 | std::io::Write::write_fmt( 76 | &mut self.head_buf, 77 | format_args!("{:?} {}\r\n", version, status), 78 | )?; 79 | if let Some(len) = self.content_length { 80 | std::io::Write::write_fmt( 81 | &mut self.head_buf, 82 | format_args!("content-length: {}\r\n", len), 83 | )?; 84 | } else { 85 | std::io::Write::write_fmt( 86 | &mut self.head_buf, 87 | format_args!("transfer-encoding: chunked\r\n"), 88 | )?; 89 | } 90 | if let Some(date) = date { 91 | std::io::Write::write_fmt(&mut self.head_buf, format_args!("date: {}\r\n", date))?; 92 | } 93 | for (header, value) in headers { 94 | // write broken up, because value may contain opaque bytes. 95 | std::io::Write::write_fmt(&mut self.head_buf, format_args!("{}: ", header))?; 96 | std::io::Write::write(&mut self.head_buf, value.as_bytes())?; 97 | std::io::Write::write(&mut self.head_buf, b"\r\n")?; 98 | } 99 | std::io::Write::write_fmt(&mut self.head_buf, format_args!("\r\n"))?; 100 | 101 | // Now everything's prepped, on to sending the header 102 | self.state = EncoderState::Head; 103 | self.encode_head(cx, buf) 104 | } 105 | 106 | fn encode_head( 107 | &mut self, 108 | cx: &mut Context<'_>, 109 | buf: &mut [u8], 110 | ) -> Poll> { 111 | // Each read is not guaranteed to read the entire head_buf. So we keep track of our place 112 | // if the read is partial, so that it can be continued on the next poll. 
113 | 114 | // Copy to to buf the shorter of (remaining head_buf) or buf 115 | let len = std::cmp::min(self.head_buf.len() - self.head_bytes_read, buf.len()); 116 | let range = self.head_bytes_read..self.head_bytes_read + len; 117 | buf[0..len].copy_from_slice(&self.head_buf[range]); 118 | self.bytes_read += len; 119 | self.head_bytes_read += len; 120 | 121 | // if entire head_buf is read, continue to body encoding, else keep state and return 122 | // Poll::Ready for this iteration 123 | if self.head_bytes_read == self.head_buf.len() { 124 | match self.content_length { 125 | Some(_) => { 126 | self.state = EncoderState::FixedBody; 127 | self.encode_fixed_body(cx, buf) 128 | } 129 | None => { 130 | self.state = EncoderState::ChunkedBody; 131 | trace!("Server response encoding: chunked body"); 132 | self.encode_chunked_body(cx, buf) 133 | } 134 | } 135 | } else { 136 | Poll::Ready(Ok(self.bytes_read)) 137 | } 138 | } 139 | 140 | fn encode_fixed_body( 141 | &mut self, 142 | cx: &mut Context<'_>, 143 | buf: &mut [u8], 144 | ) -> Poll> { 145 | // Remember that from here, the buf has not been cleared yet, so consider the head as the 146 | // first part of the buf. 147 | 148 | // first check that there's more room in buffer 149 | if self.bytes_read == buf.len() { 150 | return Poll::Ready(Ok(self.bytes_read)); 151 | } 152 | 153 | let content_length = self 154 | .content_length 155 | .expect("content_length.is_some() checked before entering method"); 156 | 157 | // Copy to to buf the shorter of (remaining body + any previous reads) or buf 158 | let upper_limit = std::cmp::min( 159 | self.bytes_read + content_length - self.body_bytes_read, 160 | buf.len(), 161 | ); 162 | let range = self.bytes_read..upper_limit; 163 | let inner_read = Pin::new(&mut self.resp.body).poll_read(cx, &mut buf[range]); 164 | match inner_read { 165 | Poll::Ready(Ok(n)) => { 166 | self.bytes_read += n; 167 | self.body_bytes_read += n; 168 | } 169 | Poll::Ready(Err(err)) => { 170 | return Poll::Ready(Err(err)); 171 | } 172 | Poll::Pending => match self.bytes_read { 173 | 0 => return Poll::Pending, 174 | n => return Poll::Ready(Ok(n)), 175 | }, 176 | } 177 | 178 | // if entire resp is read, finish. Else return Poll::Ready for another iteration 179 | if content_length == self.body_bytes_read { 180 | self.state = EncoderState::Done; 181 | Poll::Ready(Ok(self.bytes_read)) 182 | } else { 183 | self.encode_fixed_body(cx, buf) 184 | } 185 | } 186 | 187 | /// Encode an AsyncBufRead using "chunked" framing. This is used for streams 188 | /// whose length is not known up front. 
189 | fn encode_chunked_body( 190 | &mut self, 191 | cx: &mut Context<'_>, 192 | buf: &mut [u8], 193 | ) -> Poll> { 194 | let buf = &mut buf[self.bytes_read..]; 195 | match self.chunked.encode(&mut self.resp.body, cx, buf) { 196 | Poll::Ready(Ok(read)) => { 197 | self.bytes_read += read; 198 | if self.bytes_read == 0 { 199 | self.state = EncoderState::Done 200 | } 201 | Poll::Ready(Ok(self.bytes_read)) 202 | } 203 | Poll::Ready(Err(err)) => Poll::Ready(Err(err)), 204 | Poll::Pending => { 205 | if self.bytes_read > 0 { 206 | return Poll::Ready(Ok(self.bytes_read)); 207 | } 208 | Poll::Pending 209 | } 210 | } 211 | } 212 | } 213 | 214 | impl AsyncRead for Encoder { 215 | fn poll_read( 216 | mut self: Pin<&mut Self>, 217 | cx: &mut Context<'_>, 218 | buf: &mut [u8], 219 | ) -> Poll> { 220 | // bytes_read is per call to poll_read for Encoder 221 | self.bytes_read = 0; 222 | 223 | use EncoderState::*; 224 | match self.state { 225 | Start => self.start(cx, buf), 226 | Head => self.encode_head(cx, buf), 227 | FixedBody => self.encode_fixed_body(cx, buf), 228 | ChunkedBody => self.encode_chunked_body(cx, buf), 229 | Done => Poll::Ready(Ok(0)), 230 | } 231 | } 232 | } 233 | 234 | #[derive(Debug)] 235 | enum EncoderState { 236 | Start, 237 | Head, 238 | FixedBody, 239 | ChunkedBody, 240 | Done, 241 | } 242 | -------------------------------------------------------------------------------- /src/chunked/encoder.rs: -------------------------------------------------------------------------------- 1 | // chunked encoder module is largely from async-h1, with modifications to use http lib 2 | #![allow(clippy::len_zero)] 3 | #![allow(clippy::manual_saturating_arithmetic)] 4 | 5 | use futures_lite::AsyncBufRead; 6 | use std::io; 7 | use std::pin::Pin; 8 | use std::task::{Context, Poll}; 9 | use tracing::trace; 10 | 11 | use crate::body::Body; 12 | 13 | const CR: u8 = b'\r'; 14 | const LF: u8 = b'\n'; 15 | const CRLF_LEN: usize = 2; 16 | 17 | /// The encoder state. 18 | #[derive(Debug)] 19 | enum State { 20 | /// Starting state. 21 | Start, 22 | /// Streaming out chunks. 23 | EncodeChunks, 24 | /// No more chunks to stream, mark the end. 25 | EndOfChunks, 26 | /// Receiving trailers from a channel. 27 | ReceiveTrailers, 28 | /// Streaming out trailers, if we received any. 29 | EncodeTrailers, 30 | /// Writing out the final CRLF. 31 | EndOfStream, 32 | /// The stream has finished. 33 | End, 34 | } 35 | 36 | /// An encoder for chunked encoding. 37 | #[derive(Debug)] 38 | pub(crate) struct ChunkedEncoder { 39 | /// How many bytes we've written to the buffer so far. 40 | bytes_written: usize, 41 | /// The internal encoder state. 42 | state: State, 43 | } 44 | 45 | impl ChunkedEncoder { 46 | /// Create a new instance. 47 | pub(crate) fn new() -> Self { 48 | Self { 49 | state: State::Start, 50 | bytes_written: 0, 51 | } 52 | } 53 | 54 | /// Encode an AsyncBufRead using "chunked" framing. This is used for streams 55 | /// whose length is not known up front. 56 | /// 57 | /// # Format 58 | /// 59 | /// Each "chunk" uses the following encoding: 60 | /// 61 | /// ```txt 62 | /// 1. {byte length of `data` as hex}\r\n 63 | /// 2. {data}\r\n 64 | /// ``` 65 | /// 66 | /// A chunk stream is finalized by appending the following: 67 | /// 68 | /// ```txt 69 | /// 1. 0\r\n 70 | /// 2. {trailing header}\r\n (can be repeated) 71 | /// 3. 
\r\n 72 | /// ``` 73 | pub(crate) fn encode( 74 | &mut self, 75 | body: &mut Body, 76 | cx: &mut Context<'_>, 77 | buf: &mut [u8], 78 | ) -> Poll> { 79 | self.bytes_written = 0; 80 | match self.state { 81 | State::Start => self.init(body, cx, buf), 82 | State::EncodeChunks => self.encode_chunks(body, cx, buf), 83 | State::EndOfChunks => self.encode_chunks_eos(body, cx, buf), 84 | State::ReceiveTrailers => self.encode_trailers(body, cx, buf), 85 | State::EncodeTrailers => self.encode_trailers(body, cx, buf), 86 | State::EndOfStream => self.encode_eos(cx, buf), 87 | State::End => Poll::Ready(Ok(0)), 88 | } 89 | } 90 | 91 | /// Switch the internal state to a new state. 92 | fn set_state(&mut self, state: State) { 93 | use State::*; 94 | trace!("ChunkedEncoder state: {:?} -> {:?}", self.state, state); 95 | 96 | #[cfg(debug_assertions)] 97 | match self.state { 98 | Start => assert!(matches!(state, EncodeChunks)), 99 | EncodeChunks => assert!(matches!(state, EndOfChunks)), 100 | EndOfChunks => assert!(matches!(state, ReceiveTrailers)), 101 | ReceiveTrailers => assert!(matches!(state, EncodeTrailers | EndOfStream)), 102 | EncodeTrailers => assert!(matches!(state, EndOfStream)), 103 | EndOfStream => assert!(matches!(state, End)), 104 | End => panic!("No state transitions allowed after the stream has ended"), 105 | } 106 | 107 | self.state = state; 108 | } 109 | 110 | /// Init encoding. 111 | fn init( 112 | &mut self, 113 | body: &mut Body, 114 | cx: &mut Context<'_>, 115 | buf: &mut [u8], 116 | ) -> Poll> { 117 | self.set_state(State::EncodeChunks); 118 | self.encode_chunks(body, cx, buf) 119 | } 120 | 121 | /// Stream out data using chunked encoding. 122 | fn encode_chunks( 123 | &mut self, 124 | mut body: &mut Body, 125 | cx: &mut Context<'_>, 126 | buf: &mut [u8], 127 | ) -> Poll> { 128 | // Get bytes from the underlying stream. If the stream is not ready yet, 129 | // return the header bytes if we have any. 130 | let src = match Pin::new(&mut body).poll_fill_buf(cx) { 131 | Poll::Ready(Ok(n)) => n, 132 | Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), 133 | Poll::Pending => match self.bytes_written { 134 | 0 => return Poll::Pending, 135 | n => return Poll::Ready(Ok(n)), 136 | }, 137 | }; 138 | 139 | // If the stream doesn't have any more bytes left to read we're done 140 | // sending chunks and it's time to move on. 141 | if src.len() == 0 { 142 | self.set_state(State::EndOfChunks); 143 | return self.encode_chunks_eos(body, cx, buf); 144 | } 145 | 146 | // Each chunk is prefixed with the length of the data in hex, then a 147 | // CRLF, then the content, then another CRLF. Calculate how many bytes 148 | // each part should be. 149 | let buf_len = buf.len().checked_sub(self.bytes_written).unwrap_or(0); 150 | let msg_len = src.len().min(buf_len); 151 | // Calculate the max char count encoding the `len_prefix` statement 152 | // as hex would take. This is done by rounding up `log16(amt + 1)`. 153 | let hex_len = ((msg_len + 1) as f64).log(16.0).ceil() as usize; 154 | let framing_len = hex_len + CRLF_LEN * 2; 155 | let buf_upper = buf_len.checked_sub(framing_len).unwrap_or(0); 156 | let msg_len = msg_len.min(buf_upper); 157 | let len_prefix = format!("{:X}", msg_len).into_bytes(); 158 | 159 | // Request a new buf if the current buf is too small to write any data 160 | // into. Empty frames should only be sent to mark the end of a stream. 
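        // (A worked example, for illustration only: a 10-byte chunk needs a one-digit hex
        // length prefix, so framing_len = 1 + 2 * CRLF_LEN = 5 and the framed chunk takes
        // 15 bytes of `buf` in total.)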
161 | if buf.len() <= framing_len { 162 | cx.waker().wake_by_ref(); 163 | return Poll::Ready(Ok(self.bytes_written)); 164 | } 165 | 166 | // Write our frame header to the buffer. 167 | let lower = self.bytes_written; 168 | let upper = self.bytes_written + len_prefix.len(); 169 | buf[lower..upper].copy_from_slice(&len_prefix); 170 | buf[upper] = CR; 171 | buf[upper + 1] = LF; 172 | self.bytes_written += len_prefix.len() + 2; 173 | 174 | // Copy the bytes from our source into the output buffer. 175 | let lower = self.bytes_written; 176 | let upper = self.bytes_written + msg_len; 177 | buf[lower..upper].copy_from_slice(&src[0..msg_len]); 178 | Pin::new(&mut body).consume(msg_len); 179 | self.bytes_written += msg_len; 180 | 181 | // Finalize the chunk with a closing CRLF. 182 | let idx = self.bytes_written; 183 | buf[idx] = CR; 184 | buf[idx + 1] = LF; 185 | self.bytes_written += CRLF_LEN; 186 | 187 | // Finally return how many bytes we've written to the buffer. 188 | trace!("sending {} bytes", self.bytes_written); 189 | Poll::Ready(Ok(self.bytes_written)) 190 | } 191 | 192 | fn encode_chunks_eos( 193 | &mut self, 194 | body: &mut Body, 195 | cx: &mut Context<'_>, 196 | buf: &mut [u8], 197 | ) -> Poll> { 198 | // Request a new buf if the current buf is too small to write into. 199 | if buf.len() < 3 { 200 | cx.waker().wake_by_ref(); 201 | return Poll::Ready(Ok(self.bytes_written)); 202 | } 203 | 204 | // Write out the final empty chunk 205 | let idx = self.bytes_written; 206 | buf[idx] = b'0'; 207 | buf[idx + 1] = CR; 208 | buf[idx + 2] = LF; 209 | self.bytes_written += 1 + CRLF_LEN; 210 | 211 | self.set_state(State::ReceiveTrailers); 212 | self.receive_trailers(body, cx, buf) 213 | } 214 | 215 | /// Receive trailers sent to the response, and store them in an internal 216 | /// buffer. 217 | fn receive_trailers( 218 | &mut self, 219 | body: &mut Body, 220 | cx: &mut Context<'_>, 221 | buf: &mut [u8], 222 | ) -> Poll> { 223 | // TODO: actually wait for trailers to be received. 224 | self.set_state(State::EncodeTrailers); 225 | self.encode_trailers(body, cx, buf) 226 | } 227 | 228 | /// Send trailers to the buffer. 229 | fn encode_trailers( 230 | &mut self, 231 | _body: &mut Body, 232 | cx: &mut Context<'_>, 233 | buf: &mut [u8], 234 | ) -> Poll> { 235 | // TODO: actually encode trailers here. 236 | self.set_state(State::EndOfStream); 237 | self.encode_eos(cx, buf) 238 | } 239 | 240 | /// Encode the end of the stream. 241 | fn encode_eos(&mut self, _cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { 242 | let idx = self.bytes_written; 243 | // Write the final CRLF 244 | buf[idx] = CR; 245 | buf[idx + 1] = LF; 246 | self.bytes_written += CRLF_LEN; 247 | 248 | self.set_state(State::End); 249 | Poll::Ready(Ok(self.bytes_written)) 250 | } 251 | } 252 | -------------------------------------------------------------------------------- /tests/mock.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] // because of testing with and without anyhow errors 2 | 3 | //! Test Client for testing server 4 | //! Test Server for testing client 5 | 6 | use futures_lite::{AsyncBufRead, AsyncRead, AsyncWrite}; 7 | use std::io; 8 | use std::pin::Pin; 9 | use std::sync::{Arc, Mutex}; 10 | use std::task::{Context, Poll}; 11 | 12 | #[derive(Clone)] 13 | pub struct Client { 14 | // bool is true if read/written, fallse if not yet read/written 15 | // TODO make rdr and wtr structs so this is easier to understand. 
16 | read_buf: Arc, bool)>>, 17 | write_buf: Arc, usize)>>, 18 | expected: Vec, 19 | // sometimes writer needs to write more than once, like for chunks 20 | num_writes: usize, 21 | } 22 | 23 | impl Client { 24 | pub fn new(req: &str, expected_resp: &str) -> Self { 25 | Self { 26 | read_buf: Arc::new(Mutex::new((req.to_owned().into_bytes(), false))), 27 | write_buf: Arc::new(Mutex::new((Vec::new(), 0))), 28 | expected: expected_resp.to_owned().into_bytes(), 29 | num_writes: 1, 30 | } 31 | } 32 | 33 | pub fn new_with_writes(req: &str, expected_resp: &str, writes: usize) -> Self { 34 | Self { 35 | read_buf: Arc::new(Mutex::new((req.to_owned().into_bytes(), false))), 36 | write_buf: Arc::new(Mutex::new((Vec::new(), 0))), 37 | expected: expected_resp.to_owned().into_bytes(), 38 | num_writes: writes, 39 | } 40 | } 41 | 42 | pub fn assert(self) { 43 | let write_buf = self.write_buf.lock().unwrap(); 44 | let resp = remove_date(&write_buf.0); 45 | assert_eq!( 46 | String::from_utf8(resp).unwrap(), 47 | String::from_utf8(self.expected).unwrap() 48 | ); 49 | } 50 | 51 | pub fn assert_with_resp_date(self, date: &str) { 52 | let write_buf = self.write_buf.lock().unwrap(); 53 | 54 | let resp_with_date = String::from_utf8(write_buf.0.clone()).unwrap(); 55 | resp_with_date.find(date).unwrap(); 56 | 57 | let resp = remove_date(&write_buf.0); 58 | assert_eq!( 59 | String::from_utf8(resp).unwrap(), 60 | String::from_utf8(self.expected).unwrap() 61 | ); 62 | } 63 | } 64 | 65 | impl AsyncRead for Client { 66 | fn poll_read( 67 | self: Pin<&mut Self>, 68 | _cx: &mut Context, 69 | buf: &mut [u8], 70 | ) -> Poll> { 71 | let mut rdr = self.read_buf.lock().unwrap(); 72 | if !rdr.1 { 73 | rdr.1 = true; 74 | io::Read::read(&mut io::Cursor::new(&*rdr.0), buf).unwrap(); 75 | Poll::Ready(Ok(rdr.0.len())) 76 | } else { 77 | Poll::Ready(Ok(0)) 78 | } 79 | } 80 | } 81 | 82 | impl AsyncWrite for Client { 83 | fn poll_write(self: Pin<&mut Self>, _cx: &mut Context, buf: &[u8]) -> Poll> { 84 | let mut wtr = self.write_buf.lock().unwrap(); 85 | if wtr.1 < self.num_writes { 86 | wtr.1 += 1; 87 | wtr.0.extend_from_slice(buf); 88 | Poll::Ready(Ok(wtr.0.len())) 89 | } else { 90 | Poll::Ready(Ok(0)) 91 | } 92 | } 93 | 94 | fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { 95 | Poll::Ready(Ok(())) // placeholder, shouldn't hit? 96 | } 97 | 98 | fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { 99 | Poll::Ready(Ok(())) // placeholder, shouldn't hit? 100 | } 101 | } 102 | 103 | // TODO refactor to combine mock client and server? 104 | // TODO should reading or writing be read or write many times? 105 | #[derive(Clone)] 106 | pub struct Server { 107 | // bool is true if read/written, fallse if not yet read/written 108 | // TODO make rdr and wtr structs so this is easier to understand. 
109 | read_buf: Arc, bool)>>, 110 | write_buf: Arc, usize)>>, 111 | expected: Vec, 112 | // sometimes writer needs to write more than once, like for chunks 113 | num_writes: usize, 114 | } 115 | 116 | impl Server { 117 | pub fn new(expected_req: &str, resp: &str) -> Self { 118 | Self { 119 | read_buf: Arc::new(Mutex::new((resp.to_owned().into_bytes(), false))), 120 | write_buf: Arc::new(Mutex::new((Vec::new(), 0))), 121 | expected: expected_req.to_owned().into_bytes(), 122 | num_writes: 1, 123 | } 124 | } 125 | 126 | pub fn new_with_writes(expected_req: &str, resp: &str, writes: usize) -> Self { 127 | Self { 128 | read_buf: Arc::new(Mutex::new((resp.to_owned().into_bytes(), false))), 129 | write_buf: Arc::new(Mutex::new((Vec::new(), 0))), 130 | expected: expected_req.to_owned().into_bytes(), 131 | num_writes: writes, 132 | } 133 | } 134 | 135 | pub fn assert(self) { 136 | let write_buf = self.write_buf.lock().unwrap(); 137 | let req = remove_date(&write_buf.0); 138 | assert_eq!( 139 | String::from_utf8(req).unwrap(), 140 | String::from_utf8(self.expected).unwrap() 141 | ); 142 | } 143 | 144 | pub fn assert_with_resp_date(self, date: &str) { 145 | let write_buf = self.write_buf.lock().unwrap(); 146 | 147 | let req_with_date = String::from_utf8(write_buf.0.clone()).unwrap(); 148 | req_with_date.find(date).unwrap(); 149 | 150 | let req = remove_date(&write_buf.0); 151 | assert_eq!( 152 | String::from_utf8(req).unwrap(), 153 | String::from_utf8(self.expected).unwrap() 154 | ); 155 | } 156 | } 157 | 158 | impl AsyncRead for Server { 159 | fn poll_read( 160 | self: Pin<&mut Self>, 161 | _cx: &mut Context, 162 | buf: &mut [u8], 163 | ) -> Poll> { 164 | let mut rdr = self.read_buf.lock().unwrap(); 165 | if !rdr.1 { 166 | println!("hit mock server read, sending: {}", String::from_utf8(rdr.0.clone()).unwrap()); 167 | rdr.1 = true; 168 | io::Read::read(&mut io::Cursor::new(&*rdr.0), buf).unwrap(); 169 | Poll::Ready(Ok(rdr.0.len())) 170 | } else { 171 | Poll::Ready(Ok(0)) 172 | } 173 | } 174 | } 175 | 176 | impl AsyncWrite for Server { 177 | fn poll_write(self: Pin<&mut Self>, _cx: &mut Context, buf: &[u8]) -> Poll> { 178 | let mut wtr = self.write_buf.lock().unwrap(); 179 | if wtr.1 < self.num_writes { 180 | wtr.1 += 1; 181 | wtr.0.extend_from_slice(buf); 182 | Poll::Ready(Ok(wtr.0.len())) 183 | } else { 184 | Poll::Ready(Ok(0)) 185 | } 186 | } 187 | 188 | fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { 189 | Poll::Ready(Ok(())) // placeholder, shouldn't hit? 190 | } 191 | 192 | fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { 193 | Poll::Ready(Ok(())) // placeholder, shouldn't hit? 
194 | } 195 | } 196 | 197 | // just strip date from response 198 | fn remove_date(b: &[u8]) -> Vec { 199 | // just change to str and back is easier for now 200 | let s = std::str::from_utf8(b).unwrap(); 201 | if let Some(i) = s.find("date: ") { 202 | let eol = s[i + 6..].find("\r\n").expect("missing date eol"); 203 | let mut res = Vec::new(); 204 | res.extend_from_slice(&b[..i]); 205 | res.extend_from_slice(&b[i + 6 + eol + 2..]); 206 | res 207 | } else { 208 | b.to_vec() 209 | } 210 | } 211 | 212 | pub(crate) struct Cursor { 213 | inner: std::io::Cursor, 214 | } 215 | 216 | impl Cursor { 217 | #[allow(dead_code)] 218 | pub(crate) fn new(t: T) -> Self { 219 | Self { 220 | inner: std::io::Cursor::new(t), 221 | } 222 | } 223 | } 224 | 225 | impl AsyncRead for Cursor 226 | where 227 | T: AsRef<[u8]> + Unpin, 228 | { 229 | fn poll_read( 230 | mut self: Pin<&mut Self>, 231 | _cx: &mut Context<'_>, 232 | buf: &mut [u8], 233 | ) -> Poll> { 234 | Poll::Ready(std::io::Read::read(&mut self.inner, buf)) 235 | } 236 | 237 | fn poll_read_vectored( 238 | mut self: Pin<&mut Self>, 239 | _cx: &mut Context<'_>, 240 | bufs: &mut [std::io::IoSliceMut<'_>], 241 | ) -> Poll> { 242 | Poll::Ready(std::io::Read::read_vectored(&mut self.inner, bufs)) 243 | } 244 | } 245 | 246 | impl AsyncBufRead for Cursor 247 | where 248 | T: AsRef<[u8]> + Unpin, 249 | { 250 | fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { 251 | Poll::Ready(std::io::BufRead::fill_buf(&mut self.get_mut().inner)) 252 | } 253 | 254 | fn consume(mut self: Pin<&mut Self>, amt: usize) { 255 | std::io::BufRead::consume(&mut self.inner, amt) 256 | } 257 | } 258 | 259 | #[cfg(test)] 260 | mod test { 261 | use super::*; 262 | 263 | #[test] 264 | fn test_remove_date() { 265 | let input = 266 | b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\ndate: Thu, 07 May 2020 15:54:21 GMT\r\n\r\n"; 267 | let expected = b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n"; 268 | 269 | assert_eq!( 270 | String::from_utf8(remove_date(input)), 271 | String::from_utf8(expected.to_vec()) 272 | ); 273 | } 274 | } 275 | -------------------------------------------------------------------------------- /src/server/response_writer.rs: -------------------------------------------------------------------------------- 1 | use futures_lite::{io, AsyncWrite}; 2 | use futures_util::TryStreamExt; 3 | use http::{ 4 | header::{HeaderMap, HeaderValue, IntoHeaderName}, 5 | status::StatusCode, 6 | version::Version, 7 | }; 8 | use tracing::error; 9 | 10 | use crate::body::Body; 11 | use crate::response::Response; 12 | 13 | use super::encode::Encoder; 14 | use super::glitch::Glitch; 15 | 16 | pin_project_lite::pin_project! { 17 | pub(crate) struct InnerResponse { 18 | pub(crate) status: StatusCode, 19 | pub(crate) headers: HeaderMap, 20 | //url: Url, // TODO what is this for? 21 | pub(crate) version: Version, 22 | //pub(crate) extensions: Extensions, // TODO do I need this? 23 | #[pin] 24 | pub(crate)body: Body, 25 | } 26 | } 27 | 28 | impl InnerResponse { 29 | /// used for bad request in decoding. 400 30 | pub(crate) fn bad_request() -> Self { 31 | Self { 32 | status: StatusCode::BAD_REQUEST, 33 | headers: HeaderMap::new(), 34 | version: Version::default(), 35 | body: Body::empty(), 36 | } 37 | } 38 | 39 | /// used for version not supported in decoding. 
505
40 |     pub(crate) fn version_not_supported() -> Self {
41 |         Self {
42 |             status: StatusCode::HTTP_VERSION_NOT_SUPPORTED,
43 |             headers: HeaderMap::new(),
44 |             version: Version::default(),
45 |             body: Body::empty(),
46 |         }
47 |     }
48 |
49 |     /// used for unimplemented transfer-encoding in decoding. 501
50 |     pub(crate) fn not_implemented() -> Self {
51 |         Self {
52 |             status: StatusCode::NOT_IMPLEMENTED,
53 |             headers: HeaderMap::new(),
54 |             version: Version::default(),
55 |             body: Body::empty(),
56 |         }
57 |     }
58 |
59 |     pub(crate) async fn send<W>(self, writer: W) -> Result<ResponseWritten, std::io::Error>
60 |     where
61 |         W: AsyncWrite + Clone + Send + Sync + Unpin + 'static,
62 |     {
63 |         let mut encoder = Encoder::encode(self);
64 |         let mut writer = writer;
65 |         let bytes_written = match io::copy(&mut encoder, &mut writer).await {
66 |             Ok(b) => b,
67 |             Err(err) => {
68 |                 // only log, don't break the connection here. If the connection is really closed, then the
69 |                 // next decode will break the loop receiving requests
70 |                 error!("Error sending response: {}", err);
71 |                 return Err(err);
72 |             }
73 |         };
74 |
75 |         Ok(ResponseWritten { bytes_written })
76 |     }
77 | }
78 |
79 | /// `ResponseWriter` has two responsibilities:
80 | /// - Hold a `Response` which can be modified or replaced.
81 | /// - Expose a `send` method which will immediately write the Response to the Http connection.
82 | ///
83 | /// A `ResponseWriter` is initialized with a `Response` that contains:
84 | /// - An empty body
85 | /// - No headers (except that content-type defaults to `application/octet-stream` if not specified
86 | ///   and there's a body)
87 | /// - A 200 OK status
88 | ///
89 | /// You can modify the `Response` as you see fit. Note, however, that a `Body` is not
90 | /// necessarily in sync with the `content-type` headers that are sent. For example, it's possible
91 | /// to set the Body using a string, and then set the content-type header on the Response to be
92 | /// `content-type: video/mp4`. The power is in your hands.
93 | ///
94 | /// There are two convenience methods which will set the content-type:
95 | /// - `set_text`, because there's no guess as to content-type, and
96 | /// - `set_sse`, because the content-type `text/event-stream` is required.
97 | ///
98 | /// If you wish to create a `Response` separately and then apply it to the `ResponseWriter`, you can
99 | /// use `tophat::http::Response` and `tophat::Body`, and then `ResponseWriter::response_mut`.
100 | ///
101 | /// All methods on `ResponseWriter` should list which headers they modify in the doc string, and
102 | /// the type of the parameter should be reflected in the function name (i.e. `text` takes a string,
103 | /// not a stream or reader).
104 | ///
105 | /// Possible body types:
106 | /// - &str/String,
107 | /// - AsyncRead,
108 | /// - Stream (StreamExt),
109 | pub struct ResponseWriter<W>
110 | where
111 |     W: AsyncWrite + Clone + Send + Sync + Unpin + 'static,
112 | {
113 |     pub(crate) response: Response,
114 |     pub(crate) writer: W,
115 | }
116 |
117 | impl<W> ResponseWriter<W>
118 | where
119 |     W: AsyncWrite + Clone + Send + Sync + Unpin + 'static,
120 | {
121 |     /// send the response, and return the number of bytes written
122 |     pub async fn send(self) -> Result<ResponseWritten, Glitch> {
123 |         let (parts, body) = self.response.into_parts();
124 |
125 |         let inner_resp = InnerResponse {
126 |             status: parts.status,
127 |             headers: parts.headers,
128 |             version: parts.version,
129 |             body,
130 |         };
131 |
132 |         Ok(inner_resp.send(self.writer).await?)
133 | } 134 | 135 | /// Sets response to specified code and immediately sends. 136 | /// 137 | /// Devised as a shortcut so it would be easier to send a response with an empty body and 138 | /// status code. But if body is present, it will send that. (There's no effect on anything 139 | /// besides the status code) 140 | /// 141 | /// Internally panics if status code is incorrect (use at your own risk! For something safer, 142 | /// try `set_status`. 143 | pub async fn send_code(self, code: u16) -> Result { 144 | let mut this = self; 145 | this.set_code(code); 146 | 147 | this.send().await 148 | } 149 | 150 | /// Set response to specified status_code. 151 | pub fn set_status(&mut self, status: http::StatusCode) -> &mut Self { 152 | *self.response.status_mut() = status; 153 | self 154 | } 155 | 156 | /// Set response to specified code. 157 | /// 158 | /// Internally panics if code is incorrect (use at your own risk! For something safer, try 159 | /// `set_status`. 160 | pub fn set_code(&mut self, code: u16) -> &mut Self { 161 | *self.response.status_mut() = http::StatusCode::from_u16(code).unwrap(); 162 | self 163 | } 164 | 165 | /// Set response to specified body. 166 | /// 167 | /// Does not change content-type, that must be set separately in headers. 168 | pub fn set_body(&mut self, body: Body) -> &mut Self { 169 | *self.response.body_mut() = body; 170 | self 171 | } 172 | 173 | /// Append header to response. Will not replace a header with the same header name. 174 | pub fn append_header( 175 | &mut self, 176 | header_name: impl IntoHeaderName, 177 | header_value: HeaderValue, 178 | ) -> &mut Self { 179 | self.response 180 | .headers_mut() 181 | .append(header_name, header_value); 182 | self 183 | } 184 | 185 | /// Insert header to response. Replaces a header with the same header name. 186 | pub fn insert_header( 187 | &mut self, 188 | header_name: impl IntoHeaderName, 189 | header_value: HeaderValue, 190 | ) -> &mut Self { 191 | self.response 192 | .headers_mut() 193 | .insert(header_name, header_value); 194 | self 195 | } 196 | 197 | /// Mutable access to the full response. This way, if you like you can create the `Response` 198 | /// separately, and then set it in the `ResponseWriter` 199 | /// ```rust 200 | /// # use futures_util::io::{AsyncRead, AsyncWrite}; 201 | /// # use std::error::Error; 202 | /// # use tophat::{Body, Request, Response, server::{glitch::Result, ResponseWriter, ResponseWritten}}; 203 | /// async fn handler(req: Request, mut resp_wtr: ResponseWriter) -> Result 204 | /// where W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 205 | /// { 206 | /// let resp = Response::new(Body::empty()); 207 | /// *resp_wtr.response_mut() = resp; 208 | /// resp_wtr.send().await 209 | /// } 210 | /// ``` 211 | pub fn response_mut(&mut self) -> &mut Response { 212 | &mut self.response 213 | } 214 | 215 | /// Retrieve a reference to the `Response` in the `ResponseWriter` 216 | pub fn response(&self) -> &Response { 217 | &self.response 218 | } 219 | 220 | /// Set response to: 221 | /// - 200 OK 222 | /// - Content-type text/plain 223 | /// - Body from String 224 | /// 225 | pub fn set_text(&mut self, text: String) -> &mut Self { 226 | *self.response.body_mut() = text.into(); 227 | self.response 228 | .headers_mut() 229 | .insert(http::header::CONTENT_TYPE, "text/plain".parse().unwrap()); 230 | self 231 | } 232 | 233 | /// Sets the response body as a Server Sent Events response stream. 234 | /// Adds the content-type header for SSE. 
235 | /// 236 | /// Takes a `futures::Stream`, and `futures::TryStreamExt` must be in scope. 237 | pub fn set_sse(&mut self, stream: S) 238 | where 239 | S: TryStreamExt + Send + Sync + Unpin + 'static, 240 | S::Ok: AsRef<[u8]> + Send + Sync, 241 | { 242 | let stream = stream.into_async_read(); 243 | 244 | self.set_body(Body::from_reader(stream, None)); 245 | self.insert_header(http::header::CONTENT_TYPE, "text/event-stream".parse().unwrap()); 246 | } 247 | } 248 | 249 | /// A marker to ensure that a response is written inside a request handler. 250 | pub struct ResponseWritten { 251 | bytes_written: u64, 252 | } 253 | 254 | impl ResponseWritten { 255 | /// Bytes written by `ResponseWriter` 256 | pub fn bytes_written(&self) -> u64 { 257 | self.bytes_written 258 | } 259 | } 260 | -------------------------------------------------------------------------------- /src/server/identity.rs: -------------------------------------------------------------------------------- 1 | //! bare-bones Identity service 2 | //! 3 | //! Not middleware :) 4 | //! 5 | //! The service is kept in the global state (Data in the router) 6 | //! 7 | //! Only manually verified/tested, use at own risk. 8 | //! Currently has several `unwrap` which may panic. 9 | //! 10 | //! Cookies only, using jwt tokens. No custom claims. 11 | //! 12 | //! It's a bit manual, but you'll have to: 13 | //! 14 | //! - set jwt token on Response `identity.set_authorization(res)` 15 | //! - check authentication on Request `identity.authorized_user(req)` 16 | //! - forget (clear jwt token, basically sets a cookie with no name and no duration) 17 | //! `identity.forget(res)` 18 | 19 | use cookie::Cookie; 20 | use futures_util::io::{AsyncRead, AsyncWrite}; 21 | use http::header; 22 | use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; 23 | use serde::{Deserialize, Serialize}; 24 | use std::convert::TryInto; 25 | use std::fmt; 26 | use std::time::Duration; 27 | 28 | use crate::{server::ResponseWriter, Request}; 29 | 30 | /// Identity "middlware", for handling authorized sessions. 31 | #[derive(Clone)] 32 | pub struct Identity { 33 | /// The key for signing jwts. Should be kept private, but needs 34 | /// to be the same on multiple servers sharing a jwt domain. 35 | server_key: String, 36 | /// Value for the iss (issuer) jwt claim. 37 | issuer: Option, 38 | /// How long a token should be valid after creation, in seconds 39 | expiration_time: Duration, 40 | /// Cookie name (Currently only cookies supported, no Auth header). 41 | /// Default "jwt" 42 | cookie_name: String, 43 | /// Cookie path 44 | /// Default "/" 45 | /// TODO offer more granular path setting? 46 | cookie_path: String, 47 | /// Cookie secure 48 | /// Default true 49 | cookie_secure: bool, 50 | /// Cookie Http Only 51 | /// Default true 52 | cookie_http_only: bool, 53 | } 54 | 55 | impl Identity { 56 | /// Create a new instance. 57 | /// 58 | /// The `server_key` is used for signing and validating the jwt token. 
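    ///
    /// A minimal sketch of building one (the key and issuer below are placeholder values,
    /// and `Identity` is assumed to be in scope):
    ///
    /// ```rust,ignore
    /// let identity = Identity::build("some-secret-signing-key")
    ///     .issuer("example.org")
    ///     .finish();
    /// ```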
59 | pub fn build(server_key: &str) -> IdentityBuilder { 60 | IdentityBuilder::new(server_key) 61 | } 62 | 63 | /// Checked for an authorized user for the incoming request 64 | pub fn authorized_user(&self, req: &Request) -> Option { 65 | // Get Cookie and token 66 | let jwtstr = get_cookie(&req, &self.cookie_name); 67 | 68 | // Decode token 69 | if let Some(jwtstr) = jwtstr { 70 | let token = decode::( 71 | &jwtstr, 72 | &DecodingKey::from_secret(self.server_key.as_bytes()), 73 | &Validation::default(), 74 | ) 75 | .ok()?; 76 | 77 | //println!("{:?}", token); 78 | Some(token.claims.sub) 79 | } else { 80 | None 81 | } 82 | } 83 | 84 | /// Set a token on the `ResponseWriter`, which gets set in a cookie, which authorizes the user. 85 | pub fn set_auth_token(&self, user: &str, resp_wtr: &mut ResponseWriter) 86 | where 87 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 88 | { 89 | // in header set_cookie and provide token 90 | // 91 | // This should never fail 92 | let token = self.make_token(Some(user), None).unwrap(); 93 | let cookie = Cookie::build(&self.cookie_name, token) 94 | .path(&self.cookie_path) 95 | .max_age(self.expiration_time.try_into().unwrap()) // this uses time crate :( 96 | .http_only(self.cookie_http_only) 97 | .secure(self.cookie_secure) 98 | .finish(); 99 | resp_wtr.append_header(header::SET_COOKIE, cookie.to_string().parse().unwrap()); 100 | } 101 | 102 | /// Set an expired token on the `ResponseWriter`, which gets set in a cookie, which will 103 | /// effectively "log out" the user. 104 | pub fn forget(&self, resp_wtr: &mut ResponseWriter) 105 | where 106 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 107 | { 108 | // in header set_cookie and provide "blank" token 109 | // 110 | // This should never fail 111 | let token = self.make_token(None, Some(0)).unwrap(); 112 | let cookie = Cookie::build(&self.cookie_name, token) 113 | .path(&self.cookie_path) 114 | .max_age(time::Duration::seconds(0)) // this uses time crate :( 115 | .http_only(self.cookie_http_only) 116 | .secure(self.cookie_secure) 117 | .finish(); 118 | resp_wtr.append_header(header::SET_COOKIE, cookie.to_string().parse().unwrap()); 119 | } 120 | 121 | fn make_token( 122 | &self, 123 | user: Option<&str>, 124 | expiration: Option, 125 | ) -> Result { 126 | let claims = Claims { 127 | exp: expiration 128 | .unwrap_or_else(|| self.expiration_time.as_secs() + current_numeric_date()), 129 | iss: self 130 | .issuer 131 | .as_ref() 132 | .cloned() 133 | .unwrap_or_else(|| "".to_owned()), 134 | sub: user.map(|s| s.to_owned()).unwrap_or_else(|| "".to_owned()), 135 | }; 136 | 137 | encode( 138 | &Header::default(), 139 | &claims, 140 | &EncodingKey::from_secret(self.server_key.as_bytes()), 141 | ) 142 | .map_err(IdentityFail::Encode) 143 | } 144 | } 145 | 146 | // Separate builder, because there's two sets of apis, one for building and one for using. 147 | // 148 | // If it was just build and then finish, might not need a builder. 149 | /// Builder for Identity 150 | pub struct IdentityBuilder { 151 | server_key: String, 152 | issuer: Option, 153 | expiration_time: Duration, 154 | cookie_name: Option, // default "jwt" 155 | cookie_path: Option, // default "/" 156 | cookie_secure: bool, // default true 157 | cookie_http_only: bool, // default true 158 | } 159 | 160 | impl IdentityBuilder { 161 | /// Create a new instance. 162 | /// 163 | /// The `server_key` is used for signing and validating the jwt token. 
164 |     pub fn new(server_key: &str) -> IdentityBuilder {
165 |         IdentityBuilder {
166 |             server_key: server_key.to_owned(),
167 |             issuer: None,
168 |             expiration_time: Duration::from_secs(60 * 60 * 24),
169 |             cookie_name: None,
170 |             cookie_path: None,
171 |             cookie_secure: true,
172 |             cookie_http_only: true,
173 |         }
174 |     }
175 |
176 |     /// Set the cookie name.
177 |     ///
178 |     /// The default is "jwt".
179 |     pub fn cookie_name(mut self, name: &str) -> Self {
180 |         self.cookie_name = Some(name.to_owned());
181 |         self
182 |     }
183 |     /// Set the cookie path.
184 |     ///
185 |     /// The default is "/".
186 |     pub fn cookie_path(mut self, path: &str) -> Self {
187 |         self.cookie_path = Some(path.to_owned());
188 |         self
189 |     }
190 |
191 |     /// Set cookie Secure (https only)
192 |     ///
193 |     /// The default is true.
194 |     pub fn cookie_secure(mut self, secure: bool) -> Self {
195 |         self.cookie_secure = secure;
196 |         self
197 |     }
198 |
199 |     /// Set cookie http only
200 |     ///
201 |     /// The default is true.
202 |     pub fn cookie_http_only(mut self, http_only: bool) -> Self {
203 |         self.cookie_http_only = http_only;
204 |         self
205 |     }
206 |
207 |     /// Set a value for the iss (issuer) jwt claim.
208 |     ///
209 |     /// The default is to not set an issuer.
210 |     pub fn issuer(mut self, issuer: &str) -> Self {
211 |         self.issuer = Some(issuer.to_owned());
212 |         self
213 |     }
214 |
215 |     /// Set how long a token should be valid after creation.
216 |     ///
217 |     /// The default is 24 hours.
218 |     pub fn expiration_time(mut self, expiration_time: Duration) -> Self {
219 |         self.expiration_time = expiration_time;
220 |         self
221 |     }
222 |
223 |     /// Finish building an Identity
224 |     pub fn finish(self) -> Identity {
225 |         Identity {
226 |             server_key: self.server_key,
227 |             issuer: self.issuer,
228 |             expiration_time: self.expiration_time,
229 |             cookie_name: self.cookie_name.unwrap_or_else(|| "jwt".to_owned()),
230 |             cookie_path: self.cookie_path.unwrap_or_else(|| "/".to_owned()),
231 |             cookie_secure: self.cookie_secure,
232 |             cookie_http_only: self.cookie_http_only,
233 |         }
234 |     }
235 | }
236 |
237 | /// Gets the first cookie with the given name
238 | fn get_cookie(req: &Request, name: &str) -> Option<String> {
239 |     for cookie in req.headers().get_all(header::COOKIE) {
240 |         let cookie = Cookie::parse(cookie.to_str().ok()?).ok()?;
241 |         if cookie.name() == name {
242 |             return Some(cookie.value().to_string());
243 |         }
244 |     }
245 |     None
246 | }
247 |
248 | /// Get the current value for jwt NumericDate.
249 | ///
250 | /// Defined in RFC 7519 section 2 to be equivalent to POSIX.1 "Seconds
251 | /// Since the Epoch". The RFC allows a NumericDate to be non-integer
252 | /// (for sub-second resolution), but the jwt crate uses u64.
253 | fn current_numeric_date() -> u64 {
254 |     use std::time::{SystemTime, UNIX_EPOCH};
255 |     SystemTime::now()
256 |         .duration_since(UNIX_EPOCH)
257 |         .ok()
258 |         .unwrap()
259 |         .as_secs()
260 | }
261 |
262 | // Claims to token
263 | #[derive(Debug, Serialize, Deserialize)]
264 | struct Claims {
265 |     exp: u64,
266 |     iss: String,
267 |     // user
268 |     sub: String,
269 | }
270 |
271 | /// Error for Identity. Basically, the errors are for encoding or decoding the jwt token.
272 | #[derive(Debug)] 273 | pub enum IdentityFail { 274 | /// Encode error for jwt token 275 | Encode(jsonwebtoken::errors::Error), 276 | /// Decode error for jwt token 277 | Decode(jsonwebtoken::errors::Error), 278 | } 279 | 280 | impl std::error::Error for IdentityFail {} 281 | 282 | impl fmt::Display for IdentityFail { 283 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 284 | use IdentityFail::*; 285 | match self { 286 | Encode(err) => write!(f, "jwt encoding error: {}", err), 287 | Decode(err) => write!(f, "jwt decoding error: {}", err), 288 | } 289 | } 290 | } 291 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2020] [tophat developers] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/server/glitch.rs: -------------------------------------------------------------------------------- 1 | //! Response type for handling Endpoint errors. 2 | //! 3 | //! ## Overview 4 | //! Errors returned by users within an Endpoint are never meant to be bubbled up all the way to the 5 | //! server to handle; instead, they should be caught immediately after the endpoint, where they are 6 | //! transformed into a Response. 7 | //! 8 | //! 
This means that a Glitch is not an Error in the general Rust sense, it's very 9 | //! Response-specific. 10 | //! 11 | //! Without this functionality, a user will always have to create their own error responses and 12 | //! manually return them, without the convenience of Rust's built-in `Result` and `?` operator. 13 | //! 14 | //! For cases when you want to have more control over an "error" response, it's suggested to just 15 | //! build a normal response with `ResponseWriter`. For example, a 500 on database disconnection 16 | //! will work well with a Glitch, but informing a user that they've exceeded their x number of 17 | //! allotted api requests might use `ResponseWriter` to craft a json body. 18 | //! 19 | //! ## Functionality 20 | //! A `Glitch` allows you to: 21 | //! - Just use `?` on any error, and it will be turned into a 500 response. (`anyhow` feature 22 | //! only) 23 | //! - use `.map_err` to easily convert your error to a Glitch. 24 | //! 25 | //! In this system, it's easy to use standard `From` and `Into` traits to convert your custom 26 | //! errors if you want. 27 | // I think that this is unlike warp, which requires you to match on your error in a `catch`, and 28 | // then convert your error to a response then? Here, your error is converted on the spot. 29 | //! 30 | 31 | use http::{header::HeaderMap, status::StatusCode, version::Version}; 32 | use std::convert::Infallible; 33 | use std::fmt::Display; 34 | 35 | use crate::server::InnerResponse; 36 | 37 | /// Convenience type for `Result` 38 | pub type Result = std::result::Result; 39 | 40 | // similar to inner_response, but with string-only body 41 | /// Glitch is designed to be the error response for tophat. Users can either create them manually, 42 | /// or use `GlitchExt` to easily convert from `std::error::Error`. 43 | /// 44 | /// Note that if you create a message or error, they will be converted to a message string, and the 45 | /// content-type will be set to `text/plain`. For more control over the response, use 46 | /// `ResponseWriter`. 47 | #[derive(Debug)] 48 | pub struct Glitch { 49 | pub(crate) status: Option, 50 | pub(crate) headers: Option, 51 | pub(crate) version: Option, 52 | pub(crate) message: Option, 53 | 54 | // keep things simple, this is just response so no need to hold an actual error. Just print the 55 | // error string. 
56 | pub(crate) trace: Option, 57 | } 58 | 59 | impl From for Glitch 60 | where 61 | E: std::error::Error + Send + Sync + 'static, 62 | { 63 | fn from(error: E) -> Self { 64 | Self::new_with_err(error) 65 | } 66 | } 67 | 68 | impl std::default::Default for Glitch { 69 | fn default() -> Self { 70 | Self { 71 | status: None, 72 | headers: None, 73 | version: None, 74 | message: None, 75 | trace: None, 76 | } 77 | } 78 | } 79 | 80 | impl Glitch { 81 | /// Create a Glitch 82 | pub fn new() -> Self { 83 | Self { 84 | status: None, 85 | headers: None, 86 | version: None, 87 | message: None, 88 | trace: None, 89 | } 90 | } 91 | 92 | pub(crate) fn new_with_err(error: E) -> Self 93 | where 94 | E: std::error::Error + Send + Sync + 'static, 95 | { 96 | Self { 97 | status: None, 98 | headers: None, 99 | version: None, 100 | message: None, 101 | trace: Some(error.to_string()), 102 | } 103 | } 104 | 105 | pub(crate) fn new_with_status_context(status: StatusCode, context: C) -> Self 106 | where 107 | C: Display + Send + Sync + 'static, 108 | { 109 | Self { 110 | status: Some(status), 111 | headers: None, 112 | version: None, 113 | message: Some(context.to_string()), 114 | trace: None, 115 | } 116 | } 117 | 118 | pub(crate) fn new_with_status_err_context( 119 | status: StatusCode, 120 | error: E, 121 | context: C, 122 | ) -> Self 123 | where 124 | E: std::error::Error + Send + Sync + 'static, 125 | C: Display + Send + Sync + 'static, 126 | { 127 | Self { 128 | status: Some(status), 129 | headers: None, 130 | version: None, 131 | message: Some(context.to_string()), 132 | trace: Some(error.to_string()), 133 | } 134 | } 135 | 136 | pub(crate) fn into_inner_response(self, verbose: bool) -> InnerResponse { 137 | // Always start with user-created message 138 | let mut msg: String = self.message.unwrap_or_else(|| "".to_string()); 139 | 140 | if verbose { 141 | // must be a less awkward way to do this. 142 | if let Some(trace) = self.trace { 143 | #[allow(clippy::comparison_to_empty)] 144 | if msg != "" { 145 | msg = msg + "\n" + &trace; 146 | } else { 147 | msg = trace; 148 | } 149 | } 150 | } 151 | 152 | // as a default, set header to content-type text/plain if there's a message or trace. 153 | let mut headers = self.headers.unwrap_or_else(HeaderMap::new); 154 | if !msg.is_empty() { 155 | headers.insert(http::header::CONTENT_TYPE, "text/plain".parse().unwrap()); 156 | } 157 | 158 | 159 | InnerResponse { 160 | status: self.status.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), 161 | headers, 162 | version: self.version.unwrap_or(Version::HTTP_11), 163 | body: msg.into(), 164 | } 165 | } 166 | 167 | /// Set status of a Glitch 168 | pub fn set_status(&mut self, status: http::StatusCode) { 169 | self.status = Some(status); 170 | } 171 | 172 | /// Add a message to a Glitch 173 | pub fn set_message(&mut self, message: &str) { 174 | self.message = Some(message.into()); 175 | } 176 | 177 | /// Convenience method for sending a 400 178 | pub fn bad_request() -> Self { 179 | Self { 180 | status: Some(StatusCode::BAD_REQUEST), 181 | headers: None, 182 | version: None, 183 | message: None, 184 | trace: None, 185 | } 186 | } 187 | 188 | /// Convenience method for sending a 500 189 | pub fn internal_server_error() -> Self { 190 | Self { 191 | status: None, 192 | headers: None, 193 | version: None, 194 | message: None, 195 | trace: None, 196 | } 197 | } 198 | } 199 | 200 | // Context trait. 
Will set the `message` field in a glitch 201 | // Design from anyhow 202 | 203 | mod private { 204 | pub trait Sealed {} 205 | 206 | impl Sealed for std::result::Result where E: std::error::Error + Send + Sync + 'static {} 207 | 208 | impl Sealed for Option {} 209 | } 210 | 211 | /// GlitchExt makes it easy to chain onto a Result or Option, and convert into a Glitch. 212 | pub trait GlitchExt: private::Sealed { 213 | /// chain with `.glitch()?`, sets a Glitch with empty body. 214 | fn glitch(self, status: StatusCode) -> std::result::Result; 215 | 216 | /// chain with `.glitch_ctx(, "your_msg")?`, sets a Glitch with message in body. 217 | fn glitch_ctx(self, status: StatusCode, ctx: C) -> std::result::Result 218 | where 219 | C: Display + Send + Sync + 'static; 220 | 221 | /// chain with `.glitch_ctx(, || x.to_string())?`, sets a Glitch with message in body. 222 | /// 223 | /// Use when your context is set using a function, instead of just a value. 224 | fn glitch_with_ctx(self, status: StatusCode, f: F) -> std::result::Result 225 | where 226 | C: Display + Send + Sync + 'static, 227 | F: FnOnce() -> C; 228 | } 229 | 230 | impl GlitchExt for std::result::Result 231 | where 232 | E: std::error::Error + Send + Sync + 'static, 233 | { 234 | fn glitch(self, status: StatusCode) -> std::result::Result { 235 | self.map_err(|_| { 236 | let mut g = Glitch::new(); 237 | g.set_status(status); 238 | g 239 | }) 240 | } 241 | 242 | fn glitch_ctx(self, status: StatusCode, context: C) -> std::result::Result 243 | where 244 | C: Display + Send + Sync + 'static, 245 | { 246 | self.map_err(|error| Glitch::new_with_status_err_context(status, error, context)) 247 | } 248 | 249 | fn glitch_with_ctx(self, status: StatusCode, f: F) -> std::result::Result 250 | where 251 | C: Display + Send + Sync + 'static, 252 | F: FnOnce() -> C, 253 | { 254 | self.map_err(|error| Glitch::new_with_status_err_context(status, error, f())) 255 | } 256 | } 257 | 258 | impl GlitchExt for Option { 259 | fn glitch(self, status: StatusCode) -> std::result::Result { 260 | self.ok_or_else(|| { 261 | let mut g = Glitch::new(); 262 | g.set_status(status); 263 | g 264 | }) 265 | } 266 | 267 | fn glitch_ctx(self, status: StatusCode, context: C) -> std::result::Result 268 | where 269 | C: Display + Send + Sync + 'static, 270 | { 271 | self.ok_or_else(|| Glitch::new_with_status_context(status, context)) 272 | } 273 | 274 | fn glitch_with_ctx(self, status: StatusCode, f: F) -> std::result::Result 275 | where 276 | C: Display + Send + Sync + 'static, 277 | F: FnOnce() -> C, 278 | { 279 | self.ok_or_else(|| Glitch::new_with_status_context(status, f())) 280 | } 281 | } 282 | 283 | /// Convenience macro for creating a Glitch. 284 | /// 285 | /// `glitch!()`: 500 286 | /// `glitch!(StatusCode::BadRequest)`: 400 287 | /// `glitch!(StatusCode::BadRequest, "custom error")`: 400 with message in body 288 | #[macro_export] 289 | macro_rules! glitch ( 290 | () => { 291 | Glitch::internal_server_error(); 292 | }; 293 | ($code:expr) => { 294 | { 295 | let mut g= Glitch::new(); 296 | g.set_status($code); 297 | g 298 | } 299 | }; 300 | ($code:expr, $context:expr) => { 301 | { 302 | let mut g= Glitch::new(); 303 | g.set_status($code); 304 | g.set_message($context); 305 | g 306 | } 307 | }; 308 | ); 309 | 310 | #[macro_export] 311 | /// This one panics! 312 | /// 313 | /// Convenience macro for creating a Glitch. 
314 | ///
315 | /// `glitch_code!()`: 500
316 | /// `glitch_code!(400)`: 400
317 | /// `glitch_code!(400, "custom error")`: 400 with message in body
318 | macro_rules! glitch_code (
319 |     () => {
320 |         Glitch::internal_server_error()
321 |     };
322 |     ($code:expr) => {
323 |         {
324 |             let mut g = Glitch::new();
325 |             g.set_status(StatusCode::from_u16($code).unwrap());
326 |             g
327 |         }
328 |     };
329 |     ($code:expr, $context:expr) => {
330 |         {
331 |             let mut g = Glitch::new();
332 |             g.set_status(StatusCode::from_u16($code).unwrap());
333 |             g.set_message($context);
334 |             g
335 |         }
336 |     };
337 | );
338 |
--------------------------------------------------------------------------------
/src/server/decode.rs:
--------------------------------------------------------------------------------
1 | // TODO Handle all of the headers. See hyper src/proto/h1/role.rs
2 | // - transfer encoding
3 | // - connection
4 | // - expect
5 | // - upgrade
6 | // etc.
7 |
8 | use futures_lite::{io::BufReader, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9 | use http::{
10 |     request::Builder,
11 |     header::{self, HeaderName, HeaderValue},
12 | };
13 | use std::fmt;
14 | use tracing::debug;
15 |
16 | use crate::body::Body;
17 | use crate::chunked::ChunkedDecoder;
18 | use crate::Request;
19 |
20 | use super::response_writer::InnerResponse;
21 | use super::error::ServerError;
22 |
23 | const LF: u8 = b'\n';
24 |
25 | const SUPPORTED_TRANSFER_ENCODING: [&[u8]; 2] = [b"chunked", b"identity"];
26 |
27 | /// Decode an http request
28 | ///
29 | /// Errors are bubbled up and handled in `accept`; the possible decode errors and the error handler
30 | /// are defined in this module.
31 | ///
32 | /// `None` means that no request was read.
33 | pub(crate) async fn decode<IO>(mut io: IO) -> Result<Option<Request>, DecodeFail>
34 | where
35 |     IO: AsyncRead + AsyncWrite + Clone + Unpin + Send + Sync + 'static,
36 | {
37 |     use DecodeFail::*;
38 |
39 |     let mut reader = BufReader::new(io.clone());
40 |     let mut buf = Vec::new();
41 |     let mut headers = [httparse::EMPTY_HEADER; 16];
42 |     let mut httparse_req = httparse::Request::new(&mut headers);
43 |
44 |     // Keep reading bytes from the stream until we hit the end of the head.
45 |     loop {
46 |         let bytes_read = reader
47 |             .read_until(LF, &mut buf)
48 |             .await
49 |             .map_err(ConnectionLost)?;
50 |
51 |         // No bytes read, no request.
52 |         if bytes_read == 0 {
53 |             return Ok(None);
54 |         }
55 |
56 |         // We've hit the end delimiter of the head.
57 |         let idx = buf.len() - 1;
58 |         if idx >= 3 && &buf[idx - 3..=idx] == b"\r\n\r\n" {
59 |             break;
60 |         }
61 |     }
62 |
63 |     // Convert head buf into an httparse instance, and validate.
64 |     let status = httparse_req.parse(&buf).map_err(HttpHeadParse)?;
65 |     if status.is_partial() {
66 |         return Err(HttpMalformedHead);
67 |     };
68 |
69 |     // Check that req basics are here
70 |     let method = http::Method::from_bytes(httparse_req.method.ok_or(HttpNoMethod)?.as_bytes())
71 |         .map_err(HttpMethod)?;
72 |     let version = if httparse_req.version.ok_or(HttpNoVersion)? == 1 {
73 |         //TODO keep_alive = true, is_http_11 = true
74 |         http::Version::HTTP_11
75 |     } else {
76 |         //TODO keep_alive = false, is_http_11 = false
77 |         //http::Version::HTTP_10
78 |         return Err(Http10NotSupported);
79 |     };
80 |
81 |     // Start with the basic request build, so we can add headers directly.
82 |     let mut req = http::request::Builder::new();
83 |
84 |     // Now check headers for special cases (e.g.
85 |     // TODO check hyper for all the subtleties
86 |     let mut content_length = None;
87 |     let mut has_host = false;
88 |     let mut is_te = false;
89 |     let mut is_chunked = false;
90 |     #[allow(clippy::borrow_interior_mutable_const)] // TODO see if I can remove this later
91 |     for header in httparse_req.headers.iter() {
92 |         if header.name == header::CONTENT_LENGTH {
93 |             content_length = Some(
94 |                 std::str::from_utf8(header.value)
95 |                     .map_err(|_| HttpInvalidContentLength)?
96 |                     .parse::<usize>()
97 |                     .map_err(|_| HttpInvalidContentLength)?,
98 |             );
99 |         } else if header.name == header::TRANSFER_ENCODING {
100 |             // return error if transfer encoding not supported
101 |             // TODO this allocates to lowercase ascii. fix?
102 |             if !SUPPORTED_TRANSFER_ENCODING.contains(&header.value.to_ascii_lowercase().as_slice())
103 |             {
104 |                 return Err(HttpUnsupportedTransferEncoding);
105 |             }
106 |
107 |             is_te = true;
108 |             is_chunked = String::from_utf8_lossy(header.value)
109 |                 .trim()
110 |                 .eq_ignore_ascii_case("chunked");
111 |         } else if header.name == header::HOST {
112 |             has_host = true;
113 |         }
114 |
115 |         req.headers_mut().expect("Request builder error").append(
116 |             HeaderName::from_bytes(header.name.as_bytes()).map_err(HttpHeaderName)?,
117 |             HeaderValue::from_bytes(header.value).map_err(HttpHeaderValue)?,
118 |         );
119 |     }
120 |
121 |     handle_100_continue(&req, &mut io).await?;
122 |
123 |     // Now handle more complex parts of HTTP protocol
124 |
125 |     // Handle path according to https://tools.ietf.org/html/rfc2616#section-5.2
126 |     // Tophat ignores the host when determining the resource identified. However, the Host header is
127 |     // still required.
128 |     if !has_host {
129 |         return Err(HttpNoHost);
130 |     }
131 |     let path = httparse_req.path.ok_or(HttpNoPath)?;
132 |
133 |     // Handling content-length v. transfer-encoding:
134 |     // TODO double-check with https://tools.ietf.org/html/rfc7230#section-3.3.3
135 |     let content_length = content_length.unwrap_or(0);
136 |
137 |     // Decode body as fixed_body or as chunked
138 |     let body = if is_te && is_chunked {
139 |         let mut body = Body::empty();
140 |         let trailer_sender = body.send_trailers();
141 |         let reader = BufReader::new(ChunkedDecoder::new(reader, trailer_sender));
142 |         body.set_inner(reader, None);
143 |         body
144 |     } else {
145 |         Body::from_reader(reader.take(content_length as u64), Some(content_length))
146 |     };
147 |
148 |     // Finally build the rest of the req
149 |     let req = req
150 |         .method(method)
151 |         .version(version)
152 |         .uri(path)
153 |         .body(body)
154 |         .map_err(|_| HttpRequestBuild)?;
155 |
156 |     Ok(Some(req))
157 | }
158 |
159 | const EXPECT_HEADER_VALUE: &[u8] = b"100-continue";
160 | const EXPECT_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
161 |
162 | // This implementation is from async-h1, and should be spec-compliant, but async-h1 moved
163 | // to another way (that may/may not be better?) that requires use of `spawn`. See
164 | // https://tools.ietf.org/html/rfc7231#section-6.2.1 and
165 | // https://github.com/http-rs/async-h1/issues/135
166 | async fn handle_100_continue<W>(req: &Builder, wtr: &mut W) -> Result<(), DecodeFail>
167 | where
168 |     W: AsyncWrite + Unpin,
169 | {
170 |     let expect_header = req.headers_ref()
171 |         .and_then(|hs| hs.get(header::EXPECT))
172 |         .map(|h| h.as_bytes());
173 |
174 |     if let Some(EXPECT_HEADER_VALUE) = expect_header {
175 |         wtr.write_all(EXPECT_RESPONSE)
176 |             .await
177 |             .map_err(DecodeFail::ConnectionLost)?;
178 |     }
179 |
180 |     Ok(())
181 | }
182 |
183 | // Internal failures. If one needs to bubble up as an external error, convert it to a public error in
184 | // the `error` module.
185 | #[derive(Debug)]
186 | pub(crate) enum DecodeFail {
187 |     // These errors should result in a connection closure
188 |     ConnectionLost(std::io::Error),
189 |     HttpMalformedHead,
190 |     HttpUnsupportedTransferEncoding,
191 |
192 |     // Below failures should be handled with a Response, but not with connection closure.
193 |
194 |     // TODO check that these are actually errors, and not just something to handle
195 |     HttpNoPath,
196 |     HttpNoMethod,
197 |     HttpNoVersion,
198 |     HttpNoHost,
199 |     HttpInvalidContentLength,
200 |     HttpRequestBuild,
201 |     Http10NotSupported,
202 |
203 |     // conversions related to http and httparse lib
204 |     HttpHeadParse(httparse::Error),
205 |     HttpMethod(http::method::InvalidMethod),
206 |     HttpHeaderName(http::header::InvalidHeaderName),
207 |     HttpHeaderValue(http::header::InvalidHeaderValue),
208 | }
209 |
210 | impl fmt::Display for DecodeFail {
211 |     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212 |         use DecodeFail::*;
213 |         match self {
214 |             ConnectionLost(err) => write!(f, "Connection Lost: {}", err),
215 |             HttpMalformedHead => write!(f, "Http parse malformed head"),
216 |             HttpUnsupportedTransferEncoding => write!(f, "Http transfer encoding not supported"),
217 |             HttpNoPath => write!(f, "Http no path found"),
218 |             HttpNoMethod => write!(f, "Http no method found"),
219 |             HttpNoVersion => write!(f, "Http no version found"),
220 |             HttpNoHost => write!(f, "Http no host found"),
221 |             HttpInvalidContentLength => write!(f, "Http invalid content length"),
222 |             HttpRequestBuild => write!(f, "Http request could not be built"),
223 |             Http10NotSupported => write!(f, "Http version 1.0 not supported"),
224 |             HttpHeadParse(err) => write!(f, "Http header parsing error: {}", err),
225 |             HttpMethod(err) => write!(f, "Http Method error: {}", err),
226 |             HttpHeaderName(err) => write!(f, "Http Header name error: {}", err),
227 |             HttpHeaderValue(err) => write!(f, "Http Header value error: {}", err),
228 |         }
229 |     }
230 | }
231 |
232 | pub(crate) fn fail_to_response_and_log(fail: &DecodeFail) -> Option<InnerResponse> {
233 |     use DecodeFail::*;
234 |
235 |     // TODO improve logging message
236 |     debug!("Decode error: {}", fail);
237 |
238 |     match fail {
239 |         ConnectionLost(_) => None,
240 |         HttpUnsupportedTransferEncoding => Some(InnerResponse::not_implemented()),
241 |         Http10NotSupported => Some(InnerResponse::version_not_supported()),
242 |         _ => Some(InnerResponse::bad_request()),
243 |     }
244 | }
245 |
246 | pub(crate) fn fail_to_crate_err(fail: DecodeFail) -> Option<ServerError> {
247 |     use DecodeFail::*;
248 |
249 |     // TODO improve logging message
250 |     debug!("Decode crate-level error: {}", fail);
251 |
252 |     match fail {
253 |         //ConnectionLost(err) => Some(Error::ConnectionLost(err)),
254 |         HttpUnsupportedTransferEncoding => Some(ServerError::ConnectionClosedUnsupportedTransferEncoding),
255 |
_ => None, 256 | } 257 | } 258 | 259 | #[cfg(test)] 260 | mod test { 261 | use super::*; 262 | use crate::util::Cursor; 263 | use smol; 264 | 265 | #[test] 266 | fn test_handle_100_continue_does_nothing_with_no_header() { 267 | let req = http::request::Builder::new(); 268 | let mut io = Cursor::new(Vec::new()); 269 | smol::block_on(async { 270 | let result = handle_100_continue(&req, &mut io).await; 271 | assert_eq!( 272 | std::str::from_utf8(&io.into_inner()).unwrap(), 273 | "", 274 | ); 275 | 276 | assert!(result.is_ok()) 277 | }); 278 | } 279 | 280 | #[test] 281 | fn test_handle_100_continue_sends_header_if_expects_is_right() { 282 | let mut req = http::request::Builder::new(); 283 | req.headers_mut().expect("Request builder error").append( 284 | HeaderName::from_bytes(b"expect").unwrap(), 285 | HeaderValue::from_bytes(b"100-continue").unwrap(), 286 | ); 287 | let mut io = Cursor::new(Vec::new()); 288 | smol::block_on(async { 289 | let result = handle_100_continue(&req, &mut io).await; 290 | assert_eq!( 291 | std::str::from_utf8(&io.into_inner()).unwrap(), 292 | "HTTP/1.1 100 Continue\r\n\r\n", 293 | ); 294 | 295 | assert!(result.is_ok()) 296 | }); 297 | } 298 | 299 | #[test] 300 | fn test_handle_100_continue_sends_header_if_expects_is_wrong() { 301 | let mut req = http::request::Builder::new(); 302 | req.headers_mut().expect("Request builder error").append( 303 | HeaderName::from_bytes(b"expect").unwrap(), 304 | HeaderValue::from_bytes(b"111-wrong").unwrap(), 305 | ); 306 | let mut io = Cursor::new(Vec::new()); 307 | smol::block_on(async { 308 | let result = handle_100_continue(&req, &mut io).await; 309 | assert_eq!( 310 | std::str::from_utf8(&io.into_inner()).unwrap(), 311 | "", 312 | ); 313 | 314 | assert!(result.is_ok()) 315 | }); 316 | } 317 | } 318 | -------------------------------------------------------------------------------- /tests/server_basic.rs: -------------------------------------------------------------------------------- 1 | mod chunked_text_big; 2 | mod mock; 3 | 4 | use http::{ 5 | header::{self, HeaderName, HeaderValue}, 6 | method::Method, 7 | Uri, Version, 8 | }; 9 | use tophat::{server::accept, Body}; 10 | 11 | use mock::{Cursor, Client}; 12 | 13 | const RESP_200: &str = "HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n"; 14 | const RESP_400: &str = "HTTP/1.1 400 Bad Request\r\ncontent-length: 0\r\n\r\n"; 15 | 16 | #[test] 17 | fn test_request_empty_body() { 18 | smol::block_on(async { 19 | let testclient = Client::new( 20 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\n\r\n", 21 | RESP_200, 22 | ); 23 | 24 | accept(testclient.clone(), |_req, resp_wtr| async move { 25 | // Won't compile if done is not returned in Ok! 
26 | let done = resp_wtr.send().await.unwrap(); 27 | 28 | Ok(done) 29 | }) 30 | .await 31 | .unwrap(); 32 | 33 | testclient.assert(); 34 | }); 35 | } 36 | 37 | #[test] 38 | fn test_request_basic_with_body_and_query() { 39 | smol::block_on(async { 40 | let testclient = Client::new( 41 | "GET /foo/bar?one=two HTTP/1.1\r\nHost: example.org\r\nContent-Length: 6\r\n\r\ntophat", 42 | "HTTP/1.1 200 OK\r\ncontent-length: 12\r\ncontent-type: text/plain\r\n\r\nHello tophat", 43 | ); 44 | 45 | accept(testclient.clone(), |req, mut resp_wtr| async move { 46 | // some basic parsing tests 47 | assert_eq!(req.uri().path(), Uri::from_static("/foo/bar")); 48 | assert_eq!(req.uri().query(), Some("one=two")); 49 | assert_eq!(req.version(), Version::HTTP_11); 50 | assert_eq!(req.method(), Method::GET); 51 | assert_eq!( 52 | req.headers().get(header::CONTENT_LENGTH), 53 | Some(&HeaderValue::from_bytes(b"6").unwrap()) 54 | ); 55 | assert_eq!( 56 | req.headers().get(header::HOST), 57 | Some(&HeaderValue::from_bytes(b"example.org").unwrap()) 58 | ); 59 | 60 | let body = req.into_body().into_string().await.unwrap(); 61 | let res_body = format!("Hello {}", body); 62 | 63 | resp_wtr.set_body(res_body.into()); 64 | resp_wtr.insert_header("content-type", "text/plain".parse().unwrap()); 65 | 66 | let done = resp_wtr.send().await.unwrap(); 67 | 68 | Ok(done) 69 | }) 70 | .await 71 | .unwrap(); 72 | 73 | testclient.assert(); 74 | }); 75 | } 76 | #[test] 77 | fn test_request_missing_method() { 78 | smol::block_on(async { 79 | let testclient = Client::new( 80 | "/foo/bar HTTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\n\r\n", 81 | RESP_400, 82 | ); 83 | 84 | accept(testclient.clone(), |_req, resp_wtr| async move { 85 | resp_wtr.send().await 86 | }) 87 | .await 88 | .unwrap(); 89 | 90 | testclient.assert(); 91 | }); 92 | } 93 | 94 | #[test] 95 | fn test_request_missing_host() { 96 | smol::block_on(async { 97 | let testclient = Client::new( 98 | "GET /foo/bar HTTP/1.1\r\nContent-Length: 0\r\n\r\n", 99 | RESP_400, 100 | ); 101 | 102 | accept(testclient.clone(), |_req, resp_wtr| async move { 103 | resp_wtr.send().await 104 | }) 105 | .await 106 | .unwrap(); 107 | 108 | testclient.assert(); 109 | }); 110 | } 111 | 112 | #[test] 113 | // ignore host, should return abs_path or AbsoluteURI from uri 114 | fn test_request_path() { 115 | smol::block_on(async { 116 | // good uri path 117 | let testclient = Client::new( 118 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\n\r\n", 119 | RESP_200, 120 | ); 121 | 122 | accept(testclient.clone(), |req, resp_wtr| async move { 123 | assert_eq!(*req.uri(), Uri::from_static("/foo/bar")); 124 | assert_eq!(*req.uri().path(), Uri::from_static("/foo/bar")); 125 | resp_wtr.send().await 126 | }) 127 | .await 128 | .unwrap(); 129 | 130 | testclient.assert(); 131 | 132 | // good absolute uri, ignores host 133 | let testclient = Client::new( 134 | "GET https://wunder.org/foo/bar HTTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\n\r\n", 135 | RESP_200, 136 | ); 137 | 138 | accept(testclient.clone(), |req, resp_wtr| async move { 139 | assert_eq!(*req.uri(), Uri::from_static("https://wunder.org/foo/bar")); 140 | assert_eq!(*req.uri().path(), Uri::from_static("/foo/bar")); 141 | resp_wtr.send().await 142 | }) 143 | .await 144 | .unwrap(); 145 | 146 | testclient.assert(); 147 | 148 | // bad uri path 149 | let testclient = Client::new( 150 | "GET foo/bar HTTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\n\r\n", 151 | RESP_400, 152 | ); 153 | 154 | accept(testclient.clone(), 
|_req, resp_wtr| async move { 155 | resp_wtr.send().await 156 | }) 157 | .await 158 | .unwrap(); 159 | 160 | testclient.assert(); 161 | }); 162 | } 163 | 164 | #[test] 165 | fn test_request_version() { 166 | // malformed version 167 | smol::block_on(async { 168 | let testclient = Client::new( 169 | "GET /foo/bar HTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\n\r\n", 170 | RESP_400, 171 | ); 172 | 173 | accept(testclient.clone(), |_req, resp_wtr| async move { 174 | resp_wtr.send().await 175 | }) 176 | .await 177 | .unwrap(); 178 | 179 | testclient.assert(); 180 | }); 181 | 182 | // version 1.0 not supported 183 | smol::block_on(async { 184 | let testclient = Client::new( 185 | "GET /foo/bar HTTP/1.0\r\nHost: example.org\r\nContent-Length: 0\r\n\r\n", 186 | "HTTP/1.1 505 HTTP Version Not Supported\r\ncontent-length: 0\r\n\r\n", 187 | ); 188 | 189 | accept(testclient.clone(), |_req, resp_wtr| async move { 190 | resp_wtr.send().await 191 | }) 192 | .await 193 | .unwrap(); 194 | 195 | testclient.assert(); 196 | }); 197 | } 198 | 199 | // sends message _ands_ closes connection 200 | #[test] 201 | fn test_transfer_encoding_unsupported() { 202 | smol::block_on(async { 203 | let testclient = Client::new( 204 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\nTransfer-Encoding: gzip\r\n\r\n", 205 | "HTTP/1.1 501 Not Implemented\r\ncontent-length: 0\r\n\r\n", 206 | ); 207 | 208 | let res = accept(testclient.clone(), |_req, resp_wtr| async move { 209 | resp_wtr.send().await 210 | }) 211 | .await; 212 | 213 | match res { 214 | Ok(_) => panic!(), 215 | Err(err) => match err { 216 | tophat::server::ServerError::ConnectionClosedUnsupportedTransferEncoding => (), 217 | _ => panic!(), 218 | }, 219 | } 220 | 221 | testclient.assert(); 222 | }); 223 | } 224 | 225 | #[test] 226 | // TODO handle transfer-encoding chunked and content-length clash 227 | #[ignore] // temporary 228 | fn test_transfer_encoding_content_length() { 229 | smol::block_on(async { 230 | let testclient = Client::new( 231 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\nTransfer-Encoding: chunked\r\n\r\n", 232 | RESP_400, 233 | ); 234 | 235 | accept(testclient.clone(), |_req, resp_wtr| async move { 236 | resp_wtr.send().await 237 | }) 238 | .await 239 | .unwrap(); 240 | 241 | testclient.assert(); 242 | }); 243 | } 244 | 245 | #[test] 246 | fn test_dont_allow_user_set_body_type_header() { 247 | // Even if user sets the header for content-length or transfer-encoding, just ignore because 248 | // the encoding step will set it automatically 249 | // 250 | // Just test the two conflicting cases 251 | smol::block_on(async { 252 | let testclient = Client::new( 253 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\n\r\n", 254 | RESP_200, 255 | ); 256 | 257 | accept(testclient.clone(), |_req, mut resp_wtr| async move { 258 | resp_wtr.append_header(header::TRANSFER_ENCODING, "chunked".parse().unwrap()); 259 | resp_wtr.send().await 260 | }) 261 | .await 262 | .unwrap(); 263 | 264 | testclient.assert(); 265 | }); 266 | 267 | smol::block_on(async { 268 | let testclient = Client::new( 269 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\n\r\n", 270 | "HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\ncontent-type: application/octet-stream\r\n\r\n0\r\n\r\n", 271 | ); 272 | 273 | accept(testclient.clone(), |_req, mut resp_wtr| async move { 274 | resp_wtr.set_body(Body::from_reader(Cursor::new(""), None)); 275 | resp_wtr.append_header(header::CONTENT_LENGTH, 
"20".parse().unwrap()); 276 | resp_wtr.send().await 277 | }) 278 | .await 279 | .unwrap(); 280 | 281 | testclient.assert(); 282 | }); 283 | } 284 | 285 | #[test] 286 | fn test_response_date() { 287 | // make sure that date isn't doubled if it's also set in response 288 | // also make sure that the date header was passed through 289 | smol::block_on(async { 290 | let testclient = Client::new( 291 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\nTransfer-Encoding: chunked\r\n\r\n", 292 | RESP_200, 293 | ); 294 | 295 | accept(testclient.clone(), |_req, mut resp_wtr| async move { 296 | resp_wtr.append_header( 297 | header::DATE, 298 | "Wed, 21 Oct 2015 07:28:00 GMT".parse().unwrap(), 299 | ); 300 | resp_wtr.send().await 301 | }) 302 | .await 303 | .unwrap(); 304 | 305 | // One Date header should be stripped out by Client 306 | testclient.assert_with_resp_date("Wed, 21 Oct 2015 07:28:00 GMT"); 307 | }); 308 | } 309 | 310 | #[test] 311 | fn test_set_content_type_mime() { 312 | smol::block_on(async { 313 | let testclient = Client::new( 314 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\n\r\n", 315 | "HTTP/1.1 200 OK\r\ncontent-length: 0\r\ncontent-type: text/plain\r\n\r\n", 316 | ); 317 | 318 | accept(testclient.clone(), |_req, mut resp_wtr| async move { 319 | resp_wtr.append_header(header::CONTENT_TYPE, "text/plain".parse().unwrap()); 320 | resp_wtr.send().await 321 | }) 322 | .await 323 | .unwrap(); 324 | 325 | // One Date header should be stripped out by Client 326 | testclient.assert(); 327 | }); 328 | } 329 | 330 | #[test] 331 | fn test_decode_transfer_encoding_chunked() { 332 | smol::block_on(async { 333 | let testclient = Client::new( 334 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\nTransfer-Encoding: chunked\r\n\r\n\ 335 | 7\r\n\ 336 | Mozilla\r\n\ 337 | 9\r\n\ 338 | Developer\r\n\ 339 | 7\r\n\ 340 | Network\r\n\ 341 | 0\r\n\ 342 | Expires: Wed, 21 Oct 2015 07:28:00 GMT\r\n\ 343 | \r\n", 344 | RESP_200, 345 | ); 346 | 347 | accept(testclient.clone(), |req, resp_wtr| async move { 348 | // If you want to wait for trailer, need to use this method. 349 | // Reading body and trailer separately will run into borrow errors 350 | let (body, trailer) = req.into_body().into_string_with_trailer().await.unwrap(); 351 | 352 | let trailer = trailer.unwrap().unwrap(); 353 | 354 | assert_eq!(body, "MozillaDeveloperNetwork"); 355 | assert_eq!( 356 | trailer.headers.iter().collect::>(), 357 | vec![( 358 | &HeaderName::from_bytes(b"Expires").unwrap(), 359 | &HeaderValue::from_bytes(b"Wed, 21 Oct 2015 07:28:00 GMT").unwrap(), 360 | )] 361 | ); 362 | 363 | resp_wtr.send().await 364 | }) 365 | .await 366 | .unwrap(); 367 | 368 | testclient.assert(); 369 | }); 370 | 371 | // no trailer 372 | smol::block_on(async { 373 | let testclient = Client::new( 374 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\nTransfer-Encoding: chunked\r\n\r\n\ 375 | 7\r\n\ 376 | Mozilla\r\n\ 377 | 9\r\n\ 378 | Developer\r\n\ 379 | 7\r\n\ 380 | Network\r\n\ 381 | 0\r\n\ 382 | \r\n", 383 | RESP_200, 384 | ); 385 | 386 | accept(testclient.clone(), |req, resp_wtr| async move { 387 | // If you want to wait for trailer, need to use this method. 
388 | // Reading body and trailer separately will run into borrow errors 389 | let (body, trailer) = req.into_body().into_string_with_trailer().await.unwrap(); 390 | 391 | let trailer = trailer.unwrap().unwrap(); 392 | 393 | assert_eq!(body, "MozillaDeveloperNetwork"); 394 | assert!(trailer.headers.is_empty()); 395 | 396 | resp_wtr.send().await 397 | }) 398 | .await 399 | .unwrap(); 400 | 401 | testclient.assert(); 402 | }); 403 | } 404 | 405 | #[test] 406 | fn test_encode_transfer_encoding_chunked() { 407 | smol::block_on(async { 408 | // 13 is D in hexadecimal. 409 | // Need two writes because there's a chunk and then there's the end. 410 | let testclient = Client::new_with_writes( 411 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\n\r\n", 412 | "HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\ncontent-type: application/octet-stream\r\n\r\nD\r\nHello tophat!\r\n0\r\n\r\n", 413 | 2, 414 | ); 415 | 416 | accept(testclient.clone(), |_req, mut resp_wtr| async move { 417 | let body_str = Cursor::new("Hello tophat!"); 418 | resp_wtr.set_body(Body::from_reader(body_str, None)); 419 | 420 | resp_wtr.send().await 421 | }) 422 | .await 423 | .unwrap(); 424 | 425 | testclient.assert(); 426 | }); 427 | } 428 | 429 | #[test] 430 | fn test_encode_transfer_encoding_chunked_big() { 431 | smol::block_on(async { 432 | let testclient = Client::new_with_writes( 433 | "GET /foo/bar HTTP/1.1\r\nHost: example.org\r\nContent-Length: 0\r\n\r\n", 434 | chunked_text_big::RESPONSE, 435 | 3, 436 | ); 437 | 438 | accept(testclient.clone(), |_req, mut resp_wtr| async move { 439 | let body_str = Cursor::new(chunked_text_big::TEXT); 440 | resp_wtr.set_body(Body::from_reader(body_str, None)); 441 | 442 | resp_wtr.send().await 443 | }) 444 | .await 445 | .unwrap(); 446 | 447 | testclient.assert(); 448 | }); 449 | } 450 | -------------------------------------------------------------------------------- /src/server/cors.rs: -------------------------------------------------------------------------------- 1 | // Cors module based on warp's. 2 | 3 | //! Cors module 4 | //! 5 | //! Handles pre-flight 6 | //! 7 | //! Currently a super-simple, not-complete implementation. 8 | //! 9 | //! Does _not_ check for correctness of request headers and content-type. 10 | // (Does anybody? I checked warp and iron cors middleware, I don't think they do. 11 | //! 12 | //! Not yet an ergonomic api. (No builder) 13 | //! 14 | //! ## Simple cors 15 | //! Only checks for client's Origin header, and will respond with a `Access-Control-Allow-Origin` 16 | //! header only, with the specified allowed origins. 17 | //! 18 | //! ## Preflight cors 19 | //! - client method: is `Options` 20 | //! - client header: origin 21 | //! - client header: access-control-request-method 22 | //! - client header: access-control-request-headers 23 | //! 24 | //! - server status: 200 OK 25 | //! - server header: access-control-allow-origin 26 | //! - server header: access-control-allow-methods 27 | //! - server header: access-control-allow-headers 28 | //! 
- server header: access-control-max-age (86400s is one day) 29 | 30 | use futures_util::io::{AsyncRead, AsyncWrite}; 31 | use headers::{ 32 | AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMapExt, 33 | Origin, 34 | }; 35 | use http::{ 36 | header::{self, HeaderMap, HeaderName, HeaderValue}, 37 | Method, StatusCode, 38 | }; 39 | use std::collections::HashSet; 40 | use std::convert::TryFrom; 41 | 42 | use crate::{ 43 | server::{ 44 | glitch::{Glitch, Result}, 45 | ResponseWriter, 46 | }, 47 | Request, 48 | }; 49 | 50 | /// Build a Cors 51 | pub struct CorsBuilder { 52 | /// For preflight and simple, whether to add the access-control-allow-credentials header 53 | /// default false 54 | pub credentials: bool, 55 | /// For preflight only, allowed headers 56 | pub allowed_headers: HashSet, 57 | /// For preflight and simple, tell client what headers it can access 58 | pub exposed_headers: HashSet, 59 | /// For preflight only, max age 60 | pub max_age: Option, 61 | /// For preflight only, allowed methods 62 | pub methods: HashSet, 63 | /// For preflight and simple, allowed origins. Default is '*' 64 | pub origins: Option>, 65 | } 66 | 67 | impl CorsBuilder { 68 | /// Sets whether to add the `Access-Control-Allow-Credentials` header. 69 | pub fn allow_credentials(mut self, allow: bool) -> Self { 70 | self.credentials = allow; 71 | self 72 | } 73 | 74 | /// Adds a method to the existing list of allowed request methods. 75 | /// 76 | /// # Panics 77 | /// 78 | /// Panics if the provided argument is not a valid `http::Method`. 79 | pub fn allow_method(mut self, method: M) -> Self 80 | where 81 | http::Method: TryFrom, 82 | { 83 | let method = match TryFrom::try_from(method) { 84 | Ok(m) => m, 85 | _ => panic!("illegal Method"), 86 | }; 87 | self.methods.insert(method); 88 | self 89 | } 90 | 91 | /// Adds multiple methods to the existing list of allowed request methods. 92 | /// 93 | /// # Panics 94 | /// 95 | /// Panics if the provided argument is not a valid `http::Method`. 96 | pub fn allow_methods(mut self, methods: I) -> Self 97 | where 98 | I: IntoIterator, 99 | http::Method: TryFrom, 100 | { 101 | let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) { 102 | Ok(m) => m, 103 | _ => panic!("illegal Method"), 104 | }); 105 | self.methods.extend(iter); 106 | self 107 | } 108 | 109 | /// Adds a header to the list of allowed request headers. 110 | /// 111 | /// # Panics 112 | /// 113 | /// Panics if the provided argument is not a valid `http::header::HeaderName`. 114 | pub fn allow_header(mut self, header: H) -> Self 115 | where 116 | HeaderName: TryFrom, 117 | { 118 | let header = match TryFrom::try_from(header) { 119 | Ok(m) => m, 120 | _ => panic!("illegal Header"), 121 | }; 122 | self.allowed_headers.insert(header); 123 | self 124 | } 125 | 126 | /// Adds multiple headers to the list of allowed request headers. 127 | /// 128 | /// # Panics 129 | /// 130 | /// Panics if any of the headers are not a valid `http::header::HeaderName`. 131 | pub fn allow_headers(mut self, headers: I) -> Self 132 | where 133 | I: IntoIterator, 134 | HeaderName: TryFrom, 135 | { 136 | let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) { 137 | Ok(h) => h, 138 | _ => panic!("illegal Header"), 139 | }); 140 | self.allowed_headers.extend(iter); 141 | self 142 | } 143 | 144 | /// Adds a header to the list of exposed headers. 145 | /// 146 | /// # Panics 147 | /// 148 | /// Panics if the provided argument is not a valid `http::header::HeaderName`. 
149 |     pub fn expose_header<H>(mut self, header: H) -> Self
150 |     where
151 |         HeaderName: TryFrom<H>,
152 |     {
153 |         let header = match TryFrom::try_from(header) {
154 |             Ok(m) => m,
155 |             _ => panic!("illegal Header"),
156 |         };
157 |         self.exposed_headers.insert(header);
158 |         self
159 |     }
160 |
161 |     /// Adds multiple headers to the list of exposed headers.
162 |     ///
163 |     /// # Panics
164 |     ///
165 |     /// Panics if any of the headers are not a valid `http::header::HeaderName`.
166 |     pub fn expose_headers<I>(mut self, headers: I) -> Self
167 |     where
168 |         I: IntoIterator,
169 |         HeaderName: TryFrom<I::Item>,
170 |     {
171 |         let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
172 |             Ok(h) => h,
173 |             _ => panic!("illegal Header"),
174 |         });
175 |         self.exposed_headers.extend(iter);
176 |         self
177 |     }
178 |
179 |     /// Sets that *any* `Origin` header is allowed.
180 |     ///
181 |     /// # Warning
182 |     ///
183 |     /// This can allow websites you didn't intend to access this resource;
184 |     /// it is usually better to set an explicit list.
185 |     pub fn allow_any_origin(mut self) -> Self {
186 |         self.origins = None;
187 |         self
188 |     }
189 |
190 |     /// Add an origin to the existing list of allowed `Origin`s.
191 |     ///
192 |     /// # Panics
193 |     ///
194 |     /// Panics if the provided argument is not a valid `Origin`.
195 |     pub fn allow_origin(self, origin: impl IntoOrigin) -> Self {
196 |         self.allow_origins(Some(origin))
197 |     }
198 |
199 |     /// Add multiple origins to the existing list of allowed `Origin`s.
200 |     ///
201 |     /// # Panics
202 |     ///
203 |     /// Panics if any of the provided arguments is not a valid `Origin`.
204 |     pub fn allow_origins<I>(mut self, origins: I) -> Self
205 |     where
206 |         I: IntoIterator,
207 |         I::Item: IntoOrigin,
208 |     {
209 |         let iter = origins
210 |             .into_iter()
211 |             .map(IntoOrigin::into_origin)
212 |             .map(|origin| {
213 |                 origin
214 |                     .to_string()
215 |                     .parse()
216 |                     .expect("Origin is always a valid HeaderValue")
217 |             });
218 |
219 |         self.origins.get_or_insert_with(HashSet::new).extend(iter);
220 |
221 |         self
222 |     }
223 |
224 |     /// Sets the `Access-Control-Max-Age` header.
225 | pub fn max_age(mut self, seconds: u64) -> Self { 226 | self.max_age = Some(seconds); 227 | self 228 | } 229 | 230 | /// Finish building a Cors 231 | pub fn finish(self) -> Cors { 232 | let exposed_headers = if self.exposed_headers.is_empty() { 233 | None 234 | } else { 235 | Some(self.exposed_headers.into_iter().collect()) 236 | }; 237 | 238 | Cors { 239 | credentials: self.credentials, 240 | allowed_headers: self.allowed_headers.iter().cloned().collect(), 241 | allowed_headers_set: self.allowed_headers, 242 | exposed_headers, 243 | max_age: self.max_age, 244 | methods: self.methods.iter().cloned().collect(), 245 | methods_set: self.methods, 246 | origins: self.origins, 247 | } 248 | } 249 | } 250 | 251 | /// Cors 252 | /// 253 | /// See module docs for more details 254 | #[derive(Clone)] 255 | pub struct Cors { 256 | /// For preflight and simple, whether to add the access-control-allow-credentials header 257 | /// default false 258 | credentials: bool, 259 | 260 | allowed_headers_set: HashSet, 261 | /// For preflight only, allowed headers 262 | allowed_headers: AccessControlAllowHeaders, 263 | 264 | /// For preflight and simple, tell client what headers it can access 265 | exposed_headers: Option, 266 | 267 | /// For preflight only, max age 268 | max_age: Option, 269 | 270 | methods_set: HashSet, 271 | /// For preflight only, allowed methods 272 | methods: AccessControlAllowMethods, 273 | /// For preflight and simple, allowed origins. Default is '*' 274 | /// When responding, just use the origin sent by client if it's in the allowed list. 275 | origins: Option>, 276 | } 277 | 278 | impl Cors { 279 | /// Build a Cors 280 | pub fn build() -> CorsBuilder { 281 | CorsBuilder { 282 | credentials: false, 283 | allowed_headers: HashSet::new(), 284 | exposed_headers: HashSet::new(), 285 | max_age: None, 286 | methods: HashSet::new(), 287 | origins: None, 288 | } 289 | } 290 | 291 | // `Options` method differentiates preflight from simple. Does not check for correctness of a 292 | // simple request. 293 | // 294 | // The design seems a little weird in terms of error handling; basically 295 | // - Ok means continuing to endpoint. This is for both simple cors and not cors 296 | // - Err means short-circuit. This is for preflight and invalid 297 | /// Validate Cors. 298 | /// 299 | /// - handles simple Cors 300 | /// - handles preflight. 301 | /// 302 | /// See example `cors` to set up properly as middleware. 303 | pub fn validate(&self, req: &Request, resp_wtr: &mut ResponseWriter) -> Result<()> 304 | where 305 | W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static, 306 | { 307 | let req_method = req.method(); 308 | let req_origin = req.headers().get(header::ORIGIN); 309 | 310 | match (req_method, req_origin) { 311 | (&Method::OPTIONS, Some(origin)) => { 312 | // Preflight checks 313 | if !self.is_origin_allowed(origin) { 314 | return Err(Glitch::bad_request()); 315 | // TODO error message? 316 | //Err(Forbidden::OriginNotAllowed); 317 | } 318 | 319 | let headers = req.headers(); 320 | 321 | if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) { 322 | if !self.is_method_allowed(req_method) { 323 | return Err(Glitch::bad_request()); 324 | // TODO error message? 325 | //Err(Forbidden::MethodNotAllowed); 326 | } 327 | } else { 328 | println!("hit"); 329 | return Err(Glitch::bad_request()); 330 | // TODO error message? 
331 | // return Err(Forbidden::MethodNotAllowed); 332 | } 333 | 334 | if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) { 335 | // TODO error message? 336 | //let headers = req.headers() 337 | // .to_str() 338 | // .map_err(|_| Forbidden::HeaderNotAllowed)?; 339 | let headers = match req_headers.to_str() { 340 | Ok(h) => h, 341 | Err(_) => return Err(Glitch::bad_request()), 342 | }; 343 | for header in headers.split(',') { 344 | if !self.is_header_allowed(header) { 345 | return Err(Glitch::bad_request()); 346 | // TODO error message? 347 | //return Err(Forbidden::HeaderNotAllowed); 348 | } 349 | } 350 | } 351 | 352 | // If all checks successful, continue with headers for resp. 353 | // 354 | // NOTE it looks kind of weird, but a Glitch is used to have an early return for 355 | // preflight. 356 | // 357 | // set headers 358 | let mut resp = Glitch::new(); 359 | let mut headers = HeaderMap::new(); 360 | self.append_preflight_headers(&mut headers); 361 | // set allowed-origin header 362 | headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); 363 | 364 | resp.status = Some(StatusCode::OK); 365 | resp.headers = Some(headers); 366 | 367 | Err(resp) 368 | } 369 | (_, Some(origin)) => { 370 | // Simple 371 | if self.is_origin_allowed(origin) { 372 | // set common headers 373 | let mut headers = resp_wtr.response_mut().headers_mut(); 374 | self.append_common_headers(&mut headers); 375 | // set allowed-origin header 376 | resp_wtr.insert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); 377 | 378 | return Ok(()); 379 | } 380 | 381 | // If origin is not allowed 382 | Err(Glitch::bad_request()) 383 | } 384 | (_, _) => { 385 | // All other requests are not Cors 386 | Ok(()) 387 | } 388 | } 389 | } 390 | 391 | fn is_method_allowed(&self, header: &HeaderValue) -> bool { 392 | http::Method::from_bytes(header.as_bytes()) 393 | .map(|method| self.methods_set.contains(&method)) 394 | .unwrap_or(false) 395 | } 396 | 397 | fn is_header_allowed(&self, header: &str) -> bool { 398 | HeaderName::from_bytes(header.as_bytes()) 399 | .map(|header| self.allowed_headers_set.contains(&header)) 400 | .unwrap_or(false) 401 | } 402 | 403 | fn is_origin_allowed(&self, origin: &HeaderValue) -> bool { 404 | if let Some(ref allowed) = self.origins { 405 | allowed.contains(origin) 406 | } else { 407 | true 408 | } 409 | } 410 | 411 | fn append_preflight_headers(&self, headers: &mut HeaderMap) { 412 | self.append_common_headers(headers); 413 | 414 | headers.typed_insert(self.allowed_headers.clone()); 415 | headers.typed_insert(self.methods.clone()); 416 | 417 | if let Some(max_age) = self.max_age { 418 | headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into()); 419 | } 420 | } 421 | 422 | fn append_common_headers(&self, headers: &mut HeaderMap) { 423 | if self.credentials { 424 | headers.insert( 425 | header::ACCESS_CONTROL_ALLOW_CREDENTIALS, 426 | HeaderValue::from_static("true"), 427 | ); 428 | } 429 | if let Some(expose_headers_header) = &self.exposed_headers { 430 | headers.typed_insert(expose_headers_header.clone()) 431 | } 432 | } 433 | } 434 | 435 | /// Convenience trait for converting a Url into an Origin for cors 436 | pub trait IntoOrigin { 437 | /// Convert a Url into an Origin for cors 438 | fn into_origin(self) -> Origin; 439 | } 440 | 441 | impl<'a> IntoOrigin for &'a str { 442 | fn into_origin(self) -> Origin { 443 | let mut parts = self.splitn(2, "://"); 444 | let scheme = parts.next().expect("missing scheme"); 445 | let rest = 
parts.next().expect("missing scheme"); 446 | 447 | Origin::try_from_parts(scheme, rest, None).expect("invalid Origin") 448 | } 449 | } 450 | --------------------------------------------------------------------------------
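To close, here is a minimal, hypothetical sketch of how the pieces in this dump compose: a `Cors` assembled with the `CorsBuilder` methods from src/server/cors.rs, validated at the top of a handler passed to `accept`, and a `GlitchExt` call turning an ordinary error into an error response. This is not code from the repository; the import paths and re-exports, the `into_string` error type, and the exact return type of `accept` are assumptions inferred from the sources and tests shown in this dump (the shipped wiring lives in the `cors` example).

use std::sync::Arc;

use futures_lite::{AsyncRead, AsyncWrite};
use http::{header, Method, StatusCode};
// Assumed re-export paths; check src/server/mod.rs and src/lib.rs for the real ones.
use tophat::server::{accept, cors::Cors, glitch::GlitchExt, ServerError};

// Build a Cors with the CorsBuilder methods shown in src/server/cors.rs.
fn build_cors() -> Cors {
    Cors::build()
        .allow_origin("https://example.org")
        .allow_methods(vec![Method::GET, Method::POST])
        .allow_header(header::CONTENT_TYPE)
        .max_age(86400)
        .finish()
}

// `stream` stands in for one accepted connection: anything AsyncRead + AsyncWrite + Clone,
// as in the tests' mock Client (the examples use the async-dup crate for the Clone part).
async fn serve<RW>(stream: RW, cors: Arc<Cors>) -> Result<(), ServerError>
where
    RW: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static,
{
    accept(stream, move |req, mut resp_wtr| {
        let cors = cors.clone();
        async move {
            // Preflight requests and disallowed origins short-circuit here:
            // validate returns Err(Glitch), which becomes the early response.
            cors.validate(&req, &mut resp_wtr)?;

            // GlitchExt: any std::error::Error can be chained into a Glitch with a status.
            let name = req
                .into_body()
                .into_string()
                .await
                .glitch_ctx(StatusCode::BAD_REQUEST, "body must be valid utf-8")?;

            resp_wtr.set_body(format!("Hello {}", name).into());
            resp_wtr.insert_header(header::CONTENT_TYPE, "text/plain".parse().unwrap());
            resp_wtr.send().await
        }
    })
    .await?;

    Ok(())
}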