├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── examples ├── resources │ └── index.html └── simple.rs ├── src ├── app.rs ├── endpoint.rs ├── error.rs ├── filter.rs ├── filter │ ├── log.rs │ └── session.rs ├── lib.rs ├── request.rs ├── responder.rs ├── response.rs ├── router.rs ├── state.rs ├── static_files.rs ├── test_client.rs ├── test_client │ ├── test_request.rs │ └── test_response.rs └── ws.rs └── tests └── test_client.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | .idea 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "highnoon" 3 | version = "0.0.9" 4 | authors = ["Steve Lee "] 5 | edition = "2021" 6 | 7 | description = "minimal web server framework inspired by tide, but built on hyper" 8 | license = "MIT" 9 | repository = "https://github.com/sphenlee/highnoon" 10 | documentation = "https://docs.rs/highnoon" 11 | readme = "README.md" 12 | 13 | keywords = ["web", "tokio", "hyper", "http"] 14 | categories = ["web-programming::http-server"] 15 | 16 | [dependencies] 17 | anyhow = "1.0.66" 18 | async-trait = "0.1.58" 19 | bytes = "1.2.1" 20 | cookie = { version = "0.16.1", features = ["signed"] } 21 | futures-util = "0.3.25" 22 | hyper = { version = "0.14.22", features = ["server", "http1", "http2", "runtime", "tcp", "stream"] } 23 | headers = "0.3.8" 24 | mime = "0.3.16" 25 | mime_guess = "2.0.4" 26 | route-recognizer = "0.3.1" 27 | serde = "1.0.147" 28 | serde_json = "1.0.87" 29 | serde_urlencoded = "0.7.1" 30 | time = "0.3.16" 31 | tokio = { version = "1.21.2", features = ["rt-multi-thread", "net", "macros", "io-util", "fs"] } 32 | tokio-tungstenite = "0.17.2" 33 | tokio-util = { version = "0.7.4", features = ["io"] } 34 | tracing = "0.1.37" 35 | uuid = { version = "1.2.1", features = ["v4"] } 36 | 37 | 38 | 
[dev-dependencies] 39 | serde_derive = "1.0.147" 40 | tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021 Stephen Lee 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Highnoon 2 | ======== 3 | 4 | [![crates.io](https://img.shields.io/crates/v/highnoon.svg)](https://crates.io/crates/highnoon) 5 | [![API docs](https://docs.rs/highnoon/badge.svg)](https://docs.rs/highnoon) 6 | [![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE) 7 | 8 | A minimal web framework built on Hyper 9 | 10 | **This is a very early development release. 11 | While I'm pretty happy with the API so far, anything could change.** 12 | 13 | To get started, first implement the `State` trait, which holds all data shared by 14 | all the route handlers. This trait contains a single method to get a new 15 | `Context`, which is the data shared for the duration of a single request. 16 | `Context` is generally used for sharing data between filters. 17 | 18 | struct MyState; 19 | 20 | impl highnoon::State for MyState { 21 | type Context = (); 22 | 23 | fn new_context(&self) -> Self::Context { 24 | () 25 | } 26 | } 27 | 28 | Then create an `App` with your `State`, attach filters and routes, 29 | and launch the server. 30 | 31 | #[tokio::main] 32 | async fn main() -> highnoon::Result<()> { 33 | let mut app = highnoon::App::new(MyState); 34 | 35 | app.with(highnoon::filter::Log); 36 | 37 | app.at("/hello").get(|_request| async { 38 | "Hello world!\n\n" 39 | }); 40 | 41 | app.listen("0.0.0.0:8888").await?; 42 | Ok(()) 43 | } 44 | 45 | -------------------------------------------------------------------------------- /examples/resources/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Highnoon test 5 | 6 | 7 |

Highnoon test

8 |
9 |
10 |
--------------------------------------------------------------------------------
/examples/simple.rs:
--------------------------------------------------------------------------------
use headers::authorization::{Authorization, Bearer};
use highnoon::filter::session;
use highnoon::filter::session::{HasSession, Session};
use highnoon::filter::Next;
use highnoon::{App, Error, Json, Message, Request, Response, Result};
use hyper::StatusCode;
use serde_derive::Serialize;
use tokio;
use tracing::info;

/// A fake database; in a real server this would be a pool connection.
#[derive(Debug, Default)]
struct Db;

/// An extension trait to get access to the database.
trait HasDb {
    /// Borrow the database handle.
    fn get_db(&self) -> &Db;
}

/// Application state, shared by every route handler of the root app.
#[derive(Default)]
struct State {
    db: Db,
}

/// Per request context.
#[derive(Default)]
struct Context {
    session: session::Session,
}

/// Implement state for our struct: each request starts with a default `Context`.
impl highnoon::State for State {
    type Context = Context;

    fn new_context(&self) -> Context {
        Context::default()
    }
}

/// Our context has sessions.
impl session::HasSession for Context {
    fn session(&mut self) -> &mut Session {
        &mut self.session
    }
}

/// Our state has a database.
impl HasDb for State {
    fn get_db(&self) -> &Db {
        &self.db
    }
}

/// We can also extend the Request for any state that has a Db, so handlers can call
/// `req.get_db()` directly.
impl<S> HasDb for Request<S>
where
    S: highnoon::State + HasDb,
{
    fn get_db(&self) -> &Db {
        self.state().get_db()
    }
}

/// Sample payload returned as JSON by the `/json` route.
#[derive(Serialize)]
struct Sample {
    data: String,
    value: i32,
}

/// State for the mounted API sub-app; it carries no data.
#[derive(Default)]
struct ApiState;

#[derive(Default)] 82 | struct ApiContext { 83 | token: Option, 84 | } 85 | 86 | impl From for ApiContext { 87 | fn from(_: Context) -> Self { 88 | ApiContext::default() 89 | } 90 | } 91 | 92 | /// Implement state for our struct 93 | impl highnoon::State for ApiState { 94 | type Context = ApiContext; 95 | 96 | fn new_context(&self) -> ApiContext { 97 | ApiContext::default() 98 | } 99 | } 100 | 101 | /// A filter for checking token auth 102 | struct AuthCheck; 103 | 104 | #[async_trait::async_trait] 105 | impl highnoon::filter::Filter for AuthCheck { 106 | async fn apply( 107 | &self, 108 | mut req: Request, 109 | next: Next<'_, ApiState>, 110 | ) -> Result { 111 | let auth = req.header::>(); 112 | 113 | match auth { 114 | None => return Ok(Response::status(StatusCode::UNAUTHORIZED)), 115 | Some(bearer) => { 116 | info!("got bearer token: {}", bearer.0.token()); 117 | req.context_mut().token = Some(bearer.0.token().to_owned()); 118 | next.next(req).await 119 | } 120 | } 121 | } 122 | } 123 | 124 | /// A route handler that returns an Error which translates into HTTP bad request 125 | fn error_example(req: &Request) -> Result<()> { 126 | let fail = req.param("fail")?.parse::()?; 127 | 128 | if fail { 129 | Err(Error::bad_request("you asked for it")) 130 | } else { 131 | Ok(()) 132 | } 133 | } 134 | 135 | #[tokio::main] 136 | async fn main() -> Result<()> { 137 | tracing_subscriber::fmt().compact().init(); 138 | 139 | // create the root app 140 | let mut app = App::new(State::default()); 141 | 142 | // install the logging filter 143 | app.with(highnoon::filter::Log); 144 | 145 | // setup session handling 146 | let memstore = highnoon::filter::session::MemorySessionStore::new(); 147 | app.with( 148 | highnoon::filter::session::SessionFilter::new(memstore) 149 | .with_cookie_name("simple_sid") 150 | .with_expiry(time::Duration::minutes(5)) 151 | .with_callback(|cookie| { 152 | // for demo purposes - default is secure cookies 153 | cookie.set_secure(false); 154 | }), 155 
| ); 156 | 157 | // setup routes 158 | // basic route to show get and post 159 | app.at("/hello") 160 | .get(|_req| async { "Hello world!\n\n" }) 161 | .post(|mut req: Request| async move { 162 | let bytes = req.body_bytes().await?; 163 | Ok(bytes) 164 | }); 165 | 166 | // a route with a parameter, also uses session data 167 | app.at("/echo/:name") 168 | .get(|mut req: Request| async move { 169 | let seen = match req.session().get("seen") { 170 | None => 0, 171 | Some(s) => s.parse()?, 172 | }; 173 | 174 | let greeting = if seen > 1 { 175 | "You again!" 176 | } else if seen == 1 { 177 | "Welcome back" 178 | } else { 179 | "Hello" 180 | }; 181 | 182 | req.session().set("seen".to_owned(), (seen + 1).to_string()); 183 | 184 | let p = req.param("name"); 185 | Ok(match p { 186 | Err(_) => format!("{} anonymous\n\n", greeting), 187 | Ok(name) => format!("{} {}\n\n", greeting, name), 188 | }) 189 | }); 190 | 191 | // route that accesses the "database" 192 | app.at("/db").get(|req: Request| async move { 193 | let db = req.get_db(); 194 | format!("database is {:?}", db) 195 | }); 196 | 197 | // return some json 198 | app.at("/json").get(|_req| async { 199 | Json(Sample { 200 | data: "hello".to_owned(), 201 | value: 1234, 202 | }) 203 | }); 204 | 205 | // demonstrate using Err to return HTTP errors 206 | app.at("/error/:fail").get(|req| async move { 207 | error_example(&req)?; 208 | Ok("") 209 | }); 210 | 211 | // use a function as a handler 212 | app.at("/query").get(echo_stuff); 213 | 214 | // websocket 215 | app.at("/ws/:name").ws(|req, mut tx, mut rx| async move { 216 | println!("running the websocket"); 217 | 218 | let name = req.param("name")?; 219 | 220 | while let Some(msg) = rx.recv().await? 
{ 221 | println!("message: {}", msg); 222 | let reply = Message::text(format!("Hello {}, from Highnoon!", name)); 223 | tx.send(reply).await?; 224 | } 225 | 226 | println!("websocket closed"); 227 | Ok(()) 228 | }); 229 | 230 | // create a sub-app with the auth filter 231 | let mut api = App::new(ApiState::default()); 232 | api.with(AuthCheck); 233 | 234 | // check auth is working 235 | api.at("check").get(|req: Request| async move { 236 | println!("URI: {}", req.uri()); 237 | println!("Bearer: {:?}", req.context().token); 238 | StatusCode::OK 239 | }); 240 | // check that parameters get merged 241 | api.at("user/:name").get(|req: Request<_>| async move { 242 | println!("URI: {}", req.uri()); 243 | println!("params: {:?}", req.params()); 244 | StatusCode::OK 245 | }); 246 | 247 | // mount the sub-app into the root 248 | app.at("/api/:version").mount(api); 249 | 250 | // static files handling 251 | app.at("/static/*").static_files("examples/resources/"); 252 | 253 | // launch the server! 
254 | app.listen("0.0.0.0:8888").await?; 255 | Ok(()) 256 | } 257 | 258 | /// demonstrate all the request methods 259 | async fn echo_stuff(mut req: Request) -> Result { 260 | let uri = req.uri(); 261 | println!("URI: {}", uri); 262 | 263 | let method = req.method(); 264 | println!("method: {}", method); 265 | 266 | let headers = req.headers(); 267 | println!("header: {:#?}", headers); 268 | 269 | let body = req.body_bytes().await?; 270 | println!("body: {}", String::from_utf8_lossy(&body)); 271 | 272 | println!("remote addr: {}", req.remote_addr()); 273 | 274 | Ok(StatusCode::OK) 275 | } 276 | -------------------------------------------------------------------------------- /src/app.rs: -------------------------------------------------------------------------------- 1 | use crate::endpoint::Endpoint; 2 | use crate::filter::{Filter, Next}; 3 | use crate::router::{RouteTarget, Router}; 4 | use crate::state::State; 5 | use crate::static_files::StaticFiles; 6 | use crate::test_client::TestClient; 7 | use crate::ws::{WebSocketReceiver, WebSocketSender}; 8 | use crate::{Request, Responder, Response, Result}; 9 | use async_trait::async_trait; 10 | use hyper::server::conn::{AddrIncoming, AddrStream}; 11 | use hyper::server::Builder; 12 | use hyper::service::{make_service_fn, service_fn}; 13 | use hyper::{Body, Method}; 14 | use std::convert::Infallible; 15 | use std::future::Future; 16 | use std::net::SocketAddr; 17 | use std::path::PathBuf; 18 | use std::sync::Arc; 19 | use tokio::net::ToSocketAddrs; 20 | use tracing::info; 21 | 22 | /// The main entry point to highnoon. An `App` can be launched as a server 23 | /// or mounted into another `App`. 24 | /// Each `App` has a chain of [`Filters`](Filter) 25 | /// which are applied to each request. 26 | pub struct App { 27 | state: S, 28 | routes: Router, 29 | filters: Vec + Send + Sync + 'static>>, 30 | } 31 | 32 | /// Returned by [App::at] and attaches method handlers to a route. 
33 | pub struct Route<'a, 'p, S: State> { 34 | path: &'p str, 35 | app: &'a mut App, 36 | } 37 | 38 | impl<'a, 'p, S: State> Route<'a, 'p, S> { 39 | /// Attach an endpoint for a specific HTTP method 40 | pub fn method(self, method: Method, ep: impl Endpoint + Send + Sync + 'static) -> Self { 41 | self.app.routes.add(method, self.path, ep); 42 | self 43 | } 44 | 45 | /// Attach an endpoint for all HTTP methods. These will be checked only if no 46 | /// specific endpoint exists for the method. 47 | pub fn all(self, ep: impl Endpoint + Send + Sync + 'static) -> Self { 48 | self.app.routes.add_all(self.path, ep); 49 | self 50 | } 51 | 52 | /// Attach an endpoint for GET requests 53 | pub fn get(self, ep: impl Endpoint + Send + Sync + 'static) -> Self { 54 | self.method(Method::GET, ep) 55 | } 56 | 57 | /// Attach an endpoint for POST requests 58 | pub fn post(self, ep: impl Endpoint + Send + Sync + 'static) -> Self { 59 | self.method(Method::POST, ep) 60 | } 61 | 62 | /// Attach an endpoint for PUT requests 63 | pub fn put(self, ep: impl Endpoint + Send + Sync + 'static) -> Self { 64 | self.method(Method::PUT, ep) 65 | } 66 | 67 | /// Attach an endpoint for DELETE requests 68 | pub fn delete(self, ep: impl Endpoint + Send + Sync + 'static) -> Self { 69 | self.method(Method::DELETE, ep) 70 | } 71 | 72 | /// Serve static files located in the path `root`. The path should end with a wildcard segment 73 | /// (ie. `/*`). The wildcard portion of the URL will be appended to `root` to form the full 74 | /// path. The file extension is used to guess a mime type. Files outside of `root` will return 75 | /// a FORBIDDEN error code; `..` and `.` path segments are allowed as long as they do not navigate 76 | /// outside of `root`. 
77 | pub fn static_files(self, root: impl Into) -> Self { 78 | let prefix = self.path.to_owned(); // TODO - borrow issue here 79 | self.method(Method::GET, StaticFiles::new(root, prefix)) 80 | } 81 | 82 | /// Mount an app to handle all requests from this path. 83 | /// The path may contain parameters and these will be merged into 84 | /// the parameters from individual paths in the inner `App`. 85 | /// The App may have a different state type, but its `Context` must implement `From` to perform 86 | /// the conversion from the parent state's `Context` - *the inner `App`'s `new_context` won't 87 | /// be called*. 88 | pub fn mount(&mut self, app: App) 89 | where 90 | S2: State, 91 | S2::Context: From, 92 | { 93 | let path = self.path.to_owned() + "/*-highnoon-path-rest-"; 94 | let mounted = MountedApp { app: Arc::new(app) }; 95 | self.app.at(&path).all(mounted); 96 | } 97 | 98 | /// Attach a websocket handler to this route 99 | pub fn ws(self, handler: H) 100 | where 101 | H: Send + Sync + 'static + Fn(Request, WebSocketSender, WebSocketReceiver) -> F, 102 | F: Future> + Send + 'static, 103 | { 104 | self.method(Method::GET, crate::ws::endpoint(handler)); 105 | } 106 | } 107 | 108 | impl App { 109 | /// Create a new `App` with the given state. 110 | /// State must be `Send + Sync + 'static` because it gets shared by all route handlers. 111 | /// If you need inner mutability use a `Mutex` or similar. 112 | pub fn new(state: S) -> Self { 113 | Self { 114 | state, 115 | routes: Router::new(), 116 | filters: vec![], 117 | } 118 | } 119 | 120 | /// Create a test client by consuming this App. The test client can be used to send fake 121 | /// requests to the App and receive responses back. This can be used in unit and 122 | /// integration tests. 
123 | pub fn test(self) -> TestClient { 124 | TestClient::new(self) 125 | } 126 | 127 | /// Get a reference to this App's state 128 | pub fn state(&self) -> &S { 129 | &self.state 130 | } 131 | 132 | /// Append a filter to the chain. Filters are applied to all endpoints in this app, and are 133 | /// applied in the order they are registered. 134 | pub fn with(&mut self, filter: F) 135 | where 136 | F: Filter + Send + Sync + 'static, 137 | { 138 | self.filters.push(Box::new(filter)); 139 | } 140 | 141 | /// Create a route at the given path. Returns a [Route] object on which you can 142 | /// attach handlers for each HTTP method 143 | pub fn at<'a, 'p>(&'a mut self, path: &'p str) -> Route<'a, 'p, S> { 144 | Route { path, app: self } 145 | } 146 | 147 | /// Start a server listening on the given address (See [ToSocketAddrs] from tokio) 148 | /// This method only returns if there is an error. (Graceful shutdown is TODO) 149 | pub async fn listen(self, host: impl ToSocketAddrs) -> anyhow::Result<()> { 150 | let mut addrs = tokio::net::lookup_host(host).await?; 151 | let addr = addrs 152 | .next() 153 | .ok_or_else(|| anyhow::Error::msg("host lookup returned no hosts"))?; 154 | 155 | let builder = hyper::Server::try_bind(&addr)?; 156 | self.internal_serve(builder).await 157 | } 158 | 159 | /// Start a server listening on the provided [std::net::TcpListener] 160 | /// This method only returns if there is an error. 
(Graceful shutdown is TODO) 161 | pub async fn listen_on(self, tcp: std::net::TcpListener) -> anyhow::Result<()> { 162 | let builder = hyper::Server::from_tcp(tcp)?; 163 | self.internal_serve(builder).await 164 | } 165 | 166 | async fn internal_serve(self, builder: Builder) -> anyhow::Result<()> { 167 | let app = Arc::new(self); 168 | 169 | let make_svc = make_service_fn(|addr_stream: &AddrStream| { 170 | let app = app.clone(); 171 | let addr = addr_stream.remote_addr(); 172 | 173 | async move { 174 | Ok::<_, Infallible>(service_fn(move |req: hyper::Request| { 175 | let app = app.clone(); 176 | async move { 177 | App::serve_one_req(app, req, addr) 178 | .await 179 | .map_err(|err| err.into_std()) 180 | } 181 | })) 182 | } 183 | }); 184 | 185 | let server = builder.serve(make_svc); 186 | info!("server listening on {}", server.local_addr()); 187 | server.await?; 188 | Ok(()) 189 | } 190 | 191 | pub(crate) async fn serve_one_req( 192 | app: Arc>, 193 | req: hyper::Request, 194 | addr: SocketAddr, 195 | ) -> Result> { 196 | let RouteTarget { ep, params } = app.routes.lookup(req.method(), req.uri().path()); 197 | 198 | let ctx = app.state.new_context(); 199 | let req = Request::new(app.clone(), req, params, addr, ctx); 200 | 201 | let next = Next { 202 | ep, 203 | rest: &*app.filters, 204 | }; 205 | 206 | next.next(req) 207 | .await 208 | .or_else(|err| err.into_response()) 209 | .map(|resp| resp.into_inner()) 210 | } 211 | } 212 | 213 | struct MountedApp { 214 | app: Arc>, 215 | } 216 | 217 | #[async_trait] 218 | impl Endpoint for MountedApp 219 | where 220 | S2::Context: From, 221 | { 222 | async fn call(&self, req: Request) -> Result { 223 | // deconstruct the request from the outer state 224 | let (inner, params, remote_addr, context) = req.into_parts(); 225 | // get the part of the path still to be routed 226 | let path_rest = params 227 | .find("-highnoon-path-rest-") 228 | .expect("-highnoon-path-rest- is missing!"); 229 | // lookup the target for the request in 
the nested app 230 | let RouteTarget { 231 | ep, 232 | params: params2, 233 | } = self.app.routes.lookup(inner.method(), path_rest); 234 | 235 | // construct a new request for the inner state type 236 | let mut req2 = Request::new(self.app.clone(), inner, params, remote_addr, context.into()); 237 | 238 | // merge the inner params 239 | req2.merge_params(params2); 240 | 241 | // start the filter chain for the nested app 242 | let next = Next { 243 | ep, 244 | rest: &*self.app.filters, 245 | }; 246 | 247 | next.next(req2).await 248 | } 249 | } 250 | -------------------------------------------------------------------------------- /src/endpoint.rs: -------------------------------------------------------------------------------- 1 | use crate::state::State; 2 | /// Exposes the `Endpoint` trait if you want to implement it for custom types. 3 | /// 4 | /// This is not usually necessary since it's implemented for function types already. 5 | use crate::{Request, Responder, Response, Result}; 6 | use async_trait::async_trait; 7 | use std::future::Future; 8 | 9 | /// Implement `Endpoint` for a type to be used as a method handler. 10 | /// 11 | /// It is already implemented for functions of `Request` to `Result` 12 | /// which is the simplest (and most convenient) kind of handler. 13 | /// You can implement it manually for endpoints that may require some kind of local state. 14 | /// 15 | /// `Endpoint` uses the `#[async_trait]` attribute hence the signature presented in the docs here 16 | /// has been modified. 
An example of implementing using the attribute: 17 | /// ```rust 18 | /// # use highnoon::{Endpoint, State, Result, Request, Response}; 19 | /// struct NoOpEndpoint; 20 | /// 21 | /// #[async_trait::async_trait] 22 | /// impl Endpoint for NoOpEndpoint 23 | /// { 24 | /// async fn call(&self, req: Request) -> Result { 25 | /// Ok(Response::ok()) 26 | /// } 27 | /// } 28 | /// ``` 29 | #[async_trait] 30 | pub trait Endpoint { 31 | async fn call(&self, req: Request) -> Result; 32 | } 33 | 34 | #[async_trait] 35 | impl Endpoint for F 36 | where 37 | F: Send + Sync + 'static + Fn(Request) -> Fut, 38 | Fut: Future + Send + 'static, 39 | R: Responder + 'static, 40 | S: State, 41 | { 42 | async fn call(&self, req: Request) -> Result { 43 | (self)(req).await.into_response() 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use crate::{Responder, Response, Result}; 2 | use hyper::StatusCode; 3 | use std::error::Error as StdError; 4 | use std::fmt::Formatter; 5 | 6 | /// Error type expected to be returned by endpoints. 7 | /// 8 | /// It can represent an HTTP level error which is useful for helper functions 9 | /// that wish to cause an early return from a handler (using the question mark operator). 10 | /// It can also represent any other kind of error (using the `anyhow::Error` type). These 11 | /// errors are logged (if you enable the logging filter) and converted to a 500 Internal Server Error 12 | /// with no other details. 13 | /// 14 | /// HTTP level error should be created with the `http` methods (which accepts a `Responder` rather than 15 | /// just `Response`) and Internal errors should be created with the `From`/`Into` implementation. 
16 | pub enum Error { 17 | /// An error that should get returned to the client 18 | Http(Response), 19 | /// Internal errors, reported as 500 Internal Server Error and logged locally 20 | Internal(anyhow::Error), 21 | } 22 | 23 | impl Error { 24 | /// Convert this error into a boxed std::error::Error 25 | pub(crate) fn into_std(self) -> Box { 26 | match self { 27 | Error::Http(_) => panic!("http error??!"), 28 | Error::Internal(err) => err.into(), 29 | } 30 | } 31 | 32 | /// Create an Error from a `Responder` - the `Responder` will be converted to a response 33 | /// and returned to the HTTP Client exactly the same way as an `Result::Ok` would be. 34 | /// This is useful in conjunction with the `?` operator for early returns. 35 | pub fn http(resp: impl Responder) -> Self { 36 | match resp.into_response() { 37 | Ok(r) => Self::Http(r), 38 | Err(e) => e, 39 | } 40 | } 41 | 42 | /// Create a 400 Bad Request Error from a `Responder` - this method is similar to [Error::http] 43 | /// but it also sets the status code 44 | pub fn bad_request(resp: impl Responder) -> Self { 45 | Self::http((StatusCode::BAD_REQUEST, resp)) 46 | } 47 | } 48 | 49 | impl Responder for Error { 50 | fn into_response(self) -> Result { 51 | match self { 52 | Error::Http(resp) => Ok(resp), 53 | Error::Internal(_err) => { 54 | //log::error!("internal server error: {}", err); 55 | Ok(Response::status(StatusCode::INTERNAL_SERVER_ERROR)) 56 | } 57 | } 58 | } 59 | } 60 | 61 | impl From for Error 62 | where 63 | //E: std::error::Error + Send + Sync + 'static, 64 | E: Into, 65 | { 66 | fn from(e: E) -> Self { 67 | //Error::Internal(anyhow::Error::new(e)) 68 | Error::Internal(e.into()) 69 | } 70 | } 71 | 72 | impl std::fmt::Debug for Error { 73 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 74 | match self { 75 | Error::Internal(err) => f 76 | .debug_struct("Error::Internal") 77 | .field("inner", err) 78 | .finish(), 79 | Error::Http(resp) => f.debug_struct("Error::Http").field("inner", 
resp).finish(), 80 | } 81 | } 82 | } 83 | 84 | impl std::fmt::Display for Error { 85 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 86 | match self { 87 | Error::Internal(err) => write!(f, "Internal Error: {:?}", err), 88 | Error::Http(resp) => write!(f, "{:?}", resp), 89 | } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /src/filter.rs: -------------------------------------------------------------------------------- 1 | use crate::endpoint::Endpoint; 2 | /// Filters are reusable bits of logic that wrap endpoints. 3 | /// 4 | /// (These are sometimes called "middleware" in other frameworks). 5 | use crate::{Request, Response, Result, State}; 6 | use async_trait::async_trait; 7 | use std::future::Future; 8 | 9 | mod log; 10 | pub mod session; // TODO - export the needed bits of this 11 | 12 | pub use self::log::Log; 13 | 14 | /// Represents either the next Filter in the chain, or the actual endpoint if the chain is 15 | /// empty or completed. Use its `next` method to call the next filter/endpoint if the 16 | /// request should continue to be processed. 17 | pub struct Next<'a, S> 18 | where 19 | S: Send + Sync + 'static, 20 | { 21 | pub(crate) ep: &'a (dyn Endpoint + Send + Sync), 22 | pub(crate) rest: &'a [Box + Send + Sync + 'static>], 23 | } 24 | 25 | impl Next<'_, S> { 26 | /// Call either the next filter in the chain, or the actual endpoint if there are no more 27 | /// filters. Filters are not required to call next (eg. to return a Forbidden status instead) 28 | pub async fn next(self, req: Request) -> Result { 29 | match self.rest.split_first() { 30 | Some((head, rest)) => { 31 | let next = Next { ep: self.ep, rest }; 32 | head.apply(req, next).await 33 | } 34 | None => self.ep.call(req).await, 35 | } 36 | } 37 | } 38 | 39 | /// A Filter is a reusable bit of logic which wraps an endpoint to provide pre- and post-processing. 
40 | /// Filters can call the `Next` argument to continue processing, or may return early to stop the 41 | /// chain. Filters can be used for logging, authentication, cookie handling and many other uses. 42 | /// 43 | /// `Filter` uses the `#[async_trait]` attribute hence the signature presented in the docs here has 44 | /// been modified. An example of implementing using the attribute: 45 | /// ```rust 46 | /// # use highnoon::{filter::{Filter, Next}, State, Result, Request, Response}; 47 | /// struct NoOpFilter; 48 | /// 49 | /// #[async_trait::async_trait] 50 | /// impl Filter for NoOpFilter 51 | /// { 52 | /// async fn apply(&self, req: Request, next: Next<'_, S>) -> Result { 53 | /// next.next(req).await 54 | /// } 55 | /// } 56 | /// ``` 57 | #[async_trait] 58 | pub trait Filter { 59 | async fn apply(&self, req: Request, next: Next<'_, S>) -> Result; 60 | } 61 | 62 | // implement for async functions 63 | #[async_trait] 64 | impl Filter for F 65 | where 66 | S: State, 67 | F: Send + Sync + 'static + for<'n> Fn(Request, Next<'n, S>) -> Fut, 68 | Fut: Send + 'static + Future>, 69 | { 70 | async fn apply(&self, req: Request, next: Next<'_, S>) -> Result { 71 | self(req, next).await 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/filter/log.rs: -------------------------------------------------------------------------------- 1 | use crate::filter::{Filter, Next}; 2 | use crate::{Error, Request, Response, Result}; 3 | use async_trait::async_trait; 4 | 5 | use crate::state::State; 6 | use tracing::{debug, error, info, warn}; 7 | 8 | /// A logging filter. 
Logs all requests at debug level, and logs responses at error, warn or info 9 | /// level depending on the status code (5xx, 4xx, and everything else) 10 | pub struct Log; 11 | 12 | fn log_response(method: String, uri: String, resp: &Response) { 13 | let status = resp.as_ref().status(); 14 | if status.is_server_error() { 15 | error!(%method, %uri, %status, "response"); 16 | } else if status.is_client_error() { 17 | warn!(%method, %uri, %status, "response"); 18 | } else { 19 | info!(%method, %uri, %status, "response"); 20 | } 21 | } 22 | 23 | #[async_trait] 24 | impl Filter for Log { 25 | async fn apply(&self, req: Request, next: Next<'_, S>) -> Result { 26 | let method = req.method().to_string(); 27 | let uri = req.uri().to_string(); 28 | 29 | debug!(%method, %uri, "request"); 30 | 31 | let result = next.next(req).await; 32 | 33 | match &result { 34 | Ok(resp) => log_response(method, uri, resp), 35 | Err(Error::Http(resp)) => log_response(method, uri, resp), 36 | Err(Error::Internal(err)) => { 37 | error!(%method, 38 | %uri, 39 | error=%err, 40 | backtrace=?err, 41 | "internal server error" 42 | ); 43 | } 44 | } 45 | 46 | result 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/filter/session.rs: -------------------------------------------------------------------------------- 1 | use crate::filter::{Filter, Next}; 2 | use crate::{Request, Response, Result}; 3 | 4 | use crate::state::State; 5 | use async_trait::async_trait; 6 | use cookie::Cookie; 7 | use headers::{Header, SetCookie}; 8 | use std::borrow::Cow; 9 | use std::collections::HashMap; 10 | use std::sync::atomic::{AtomicBool, Ordering}; 11 | use std::sync::Arc; 12 | use std::sync::Mutex; 13 | use tokio::sync::Mutex as AsyncMutex; 14 | use tracing::debug; 15 | use uuid::Uuid; 16 | 17 | /// Trait for session storage 18 | #[async_trait] 19 | pub trait SessionStore { 20 | /// Get the data associated with session 21 | async fn get(&self, id: &str) -> 
Result>; 22 | /// Set the data for a session 23 | async fn set(&mut self, id: String, value: String) -> Result<()>; 24 | /// Clear data for a session 25 | async fn clear(&mut self, id: &str) -> Result<()>; 26 | } 27 | 28 | /// Memory backed implementation of session storage. 29 | /// NOTE this is only meant for demos and examples. In a real server 30 | /// you would store sessions externally (e.g. in redis or a database) 31 | #[derive(Default)] 32 | pub struct MemorySessionStore { 33 | data: HashMap, 34 | } 35 | 36 | impl MemorySessionStore { 37 | /// Create a new memory session store 38 | pub fn new() -> Self { 39 | Self::default() 40 | } 41 | } 42 | 43 | #[async_trait] 44 | impl SessionStore for MemorySessionStore { 45 | async fn get(&self, id: &str) -> Result> { 46 | debug!(id, "memory store get"); 47 | Ok(self.data.get(id).cloned()) 48 | } 49 | 50 | async fn set(&mut self, id: String, value: String) -> Result<()> { 51 | debug!(%id, %value, "memory store set"); 52 | self.data.insert(id, value); 53 | Ok(()) 54 | } 55 | 56 | async fn clear(&mut self, id: &str) -> Result<()> { 57 | debug!(id, "memory store clear"); 58 | self.data.remove(id); 59 | Ok(()) 60 | } 61 | } 62 | 63 | pub const DEFAULT_COOKIE_NAME: &str = "sid"; 64 | 65 | type DynCookieCallback = dyn Fn(&mut Cookie) + Send + Sync + 'static; 66 | 67 | /// A filter for implementing basic session support 68 | /// 69 | /// This filter requires that the Context implements HasSession 70 | pub struct SessionFilter { 71 | cookie_name: Cow<'static, str>, 72 | expiry: time::Duration, 73 | cookie_callback: Option>, 74 | store: AsyncMutex>, 75 | } 76 | 77 | impl SessionFilter { 78 | /// Create a new session filter using the provided store 79 | /// The default cookie name is [DEFAULT_COOKIE_NAME] and expiry is set to one hour 80 | pub fn new(store: impl SessionStore + Send + Sync + 'static) -> SessionFilter { 81 | SessionFilter { 82 | cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME), 83 | expiry: 
time::Duration::hours(1), 84 | cookie_callback: None, 85 | store: AsyncMutex::new(Box::new(store)), 86 | } 87 | } 88 | 89 | /// Set the name of the cookie used to store the session ID 90 | pub fn with_cookie_name(mut self, name: impl Into>) -> Self { 91 | self.cookie_name = name.into(); 92 | self 93 | } 94 | 95 | /// Set the expiry time set on the session ID cookie 96 | pub fn with_expiry(mut self, expiry: time::Duration) -> Self { 97 | self.expiry = expiry; 98 | self 99 | } 100 | 101 | /// Set a callback function to be used to customise the session ID cookie. 102 | /// The callback is called with the cookie before it is stored in the headers so you can change 103 | /// most settings (changing the name or value of the cookie may prevent sessions from working, 104 | /// so only change settings like same site, secure, etc...) 105 | pub fn with_callback(mut self, callback: F) -> Self 106 | where 107 | F: Fn(&mut Cookie) + Send + Sync + 'static, 108 | { 109 | self.cookie_callback = Some(Box::new(callback)); 110 | self 111 | } 112 | } 113 | 114 | #[derive(Default)] 115 | struct SessionInner { 116 | modified: AtomicBool, 117 | data: Mutex>, 118 | } 119 | 120 | /// A session 121 | #[derive(Default)] 122 | pub struct Session { 123 | inner: Arc, 124 | } 125 | 126 | impl SessionInner { 127 | fn get(&self, key: &str) -> Option { 128 | debug!(key, "session get"); 129 | let data = self.data.lock().unwrap(); 130 | data.get(key).cloned() 131 | } 132 | 133 | fn set(&self, key: String, value: String) { 134 | debug!(%key, %value, "session set"); 135 | self.data.lock().unwrap().insert(key, value); 136 | self.modified.store(true, Ordering::Relaxed); 137 | } 138 | 139 | fn is_modified(&self) -> bool { 140 | self.modified.load(Ordering::Relaxed) 141 | } 142 | 143 | fn load(&self, data: HashMap) { 144 | *self.data.lock().unwrap() = data; 145 | 146 | // we just loaded fresh data into the session, so clear modified flag to 147 | // detect if any changes are made that need to be saved back 
to storage 148 | self.modified.store(false, Ordering::Relaxed); 149 | } 150 | } 151 | 152 | impl Session { 153 | /// Get a value from the session 154 | pub fn get(&self, key: &str) -> Option { 155 | self.inner.get(key) 156 | } 157 | 158 | /// Store a value into the session 159 | pub fn set(&self, key: String, value: String) { 160 | self.inner.set(key, value) 161 | } 162 | 163 | /// Determine if the session has been modified 164 | pub fn is_modified(&self) -> bool { 165 | self.inner.is_modified() 166 | } 167 | } 168 | 169 | /// This trait must be implemented by the Context type in order to use the 170 | /// SessionFilter 171 | pub trait HasSession { 172 | /// Get a reference to the Session for this current request 173 | fn session(&mut self) -> &mut Session; 174 | } 175 | 176 | /// Implement HasSession on requests where the Context has sessions 177 | impl HasSession for Request 178 | where 179 | S: State, 180 | S::Context: HasSession, 181 | { 182 | fn session(&mut self) -> &mut Session { 183 | self.context_mut().session() 184 | } 185 | } 186 | 187 | #[async_trait] 188 | impl Filter for SessionFilter 189 | where 190 | S: State, 191 | S::Context: HasSession, 192 | { 193 | async fn apply(&self, mut req: Request, next: Next<'_, S>) -> Result { 194 | let session = Arc::clone(&req.session().inner); 195 | 196 | let maybe_sid = req 197 | .cookies()? 
198 | .get(self.cookie_name.as_ref()) 199 | .map(|c| c.value().to_owned()); 200 | 201 | let sid = if let Some(sid) = maybe_sid { 202 | debug!(%sid, "request has session cookie"); 203 | 204 | let store = self.store.lock().await; 205 | let raw_data = store.get(&sid).await?.unwrap_or_default(); 206 | let data = serde_urlencoded::from_str(&raw_data)?; 207 | session.load(data); 208 | sid 209 | } else { 210 | debug!("request has no session cookie"); 211 | Uuid::new_v4().to_string() 212 | }; 213 | 214 | let mut resp = next.next(req).await?; 215 | 216 | if session.is_modified() { 217 | debug!("session was modified"); 218 | 219 | let mut store = self.store.lock().await; 220 | let raw_data = { 221 | let data = session.data.lock().unwrap(); 222 | serde_urlencoded::to_string(&*data)? 223 | }; 224 | 225 | let mut cookie = Cookie::new(self.cookie_name.as_ref(), &sid); 226 | cookie.set_http_only(true); 227 | cookie.set_secure(true); 228 | cookie.set_same_site(cookie::SameSite::Strict); 229 | 230 | let expiry = time::OffsetDateTime::now_utc() + self.expiry; 231 | cookie.set_expires(expiry); 232 | 233 | if let Some(ref callback) = self.cookie_callback { 234 | callback(&mut cookie); 235 | } 236 | 237 | resp.set_raw_header(SetCookie::name(), cookie.to_string())?; 238 | 239 | store.set(sid, raw_data).await?; 240 | } 241 | 242 | Ok(resp) 243 | } 244 | } 245 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub use headers; 2 | pub use hyper::{Method, StatusCode}; 3 | pub use mime::Mime; 4 | pub use tokio_tungstenite::tungstenite::Message; 5 | 6 | mod app; 7 | mod endpoint; 8 | mod error; 9 | pub mod filter; 10 | mod request; 11 | mod responder; 12 | mod response; 13 | mod router; 14 | mod state; 15 | mod static_files; 16 | mod test_client; 17 | pub mod ws; 18 | 19 | pub use app::{App, Route}; 20 | pub use endpoint::Endpoint; 21 | pub use error::Error; 22 | 
pub use request::Request; 23 | pub use responder::{Form, Json, Responder}; 24 | pub use response::Response; 25 | pub use state::State; 26 | 27 | pub type Result = std::result::Result; 28 | -------------------------------------------------------------------------------- /src/request.rs: -------------------------------------------------------------------------------- 1 | use crate::state::State; 2 | use crate::{App, Error, Result}; 3 | use cookie::{Cookie, CookieJar}; 4 | use headers::{Header, HeaderMapExt}; 5 | use hyper::header::HeaderValue; 6 | use hyper::{body::Buf, Body, HeaderMap, StatusCode}; 7 | use route_recognizer::Params; 8 | use serde::de::DeserializeOwned; 9 | use std::io::Read; 10 | use std::net::SocketAddr; 11 | use std::sync::Arc; 12 | use tracing::error; 13 | 14 | /// An incoming request 15 | pub struct Request { 16 | app: Arc>, 17 | context: S::Context, 18 | params: Params, 19 | inner: hyper::Request, 20 | remote_addr: SocketAddr, 21 | } 22 | 23 | impl Request { 24 | pub(crate) fn new( 25 | app: Arc>, 26 | inner: hyper::Request, 27 | params: Params, 28 | remote_addr: SocketAddr, 29 | context: S::Context, 30 | ) -> Self { 31 | Self { 32 | app, 33 | context, 34 | inner, 35 | params, 36 | remote_addr, 37 | } 38 | } 39 | 40 | pub(crate) fn into_parts(self) -> (hyper::Request, Params, SocketAddr, S::Context) { 41 | (self.inner, self.params, self.remote_addr, self.context) 42 | } 43 | 44 | pub(crate) fn merge_params(&mut self, params: Params) { 45 | for (k, v) in params.iter() { 46 | self.params.insert(k.to_owned(), v.to_owned()); 47 | } 48 | } 49 | 50 | /// Get a reference to the App's state 51 | pub fn state(&self) -> &S { 52 | self.app.state() 53 | } 54 | 55 | /// Get a reference to the request's context 56 | pub fn context(&self) -> &S::Context { 57 | &self.context 58 | } 59 | 60 | /// Get a mut reference to the request's context 61 | pub fn context_mut(&mut self) -> &mut S::Context { 62 | &mut self.context 63 | } 64 | 65 | /// Get the HTTP method 
being used by this request 66 | pub fn method(&self) -> &hyper::Method { 67 | self.inner.method() 68 | } 69 | 70 | /// Get the URI that was used for this request 71 | pub fn uri(&self) -> &hyper::Uri { 72 | self.inner.uri() 73 | } 74 | 75 | /// Parse the URI query string into an instance of `T` that derives `Deserialize`. 76 | /// 77 | /// (To get the raw query string access it via `req.uri().query()`). 78 | /// If there is no query string, deserialize an empty string. 79 | pub fn query(&self) -> Result { 80 | // if there is no query string we can default to empty string 81 | // serde_urlencode will work if T has all optional fields 82 | let q = self.inner.uri().query().unwrap_or(""); 83 | let t = serde_urlencoded::from_str::(q) 84 | .map_err(|err| Error::bad_request(format!("invalid query parameter: {}", err)))?; 85 | Ok(t) 86 | } 87 | 88 | /// Get a typed header from the request 89 | /// (See also `headers`) 90 | pub fn header(&self) -> Option { 91 | self.inner.headers().typed_get() 92 | } 93 | 94 | /// Get all headers as a `HeaderMap` 95 | pub fn headers(&self) -> &HeaderMap { 96 | self.inner.headers() 97 | } 98 | 99 | /// Get the request's cookies 100 | pub fn cookies(&self) -> Result { 101 | let mut cookies = CookieJar::new(); 102 | 103 | for val in self.inner.headers().get_all(headers::Cookie::name()) { 104 | let c = Cookie::parse(val.to_str()?)?; 105 | cookies.add(c.into_owned()); 106 | } 107 | 108 | Ok(cookies) 109 | } 110 | 111 | /// Get a route parameter (eg. 
`:key` or `*key` segments in the URI path) 112 | /// 113 | /// If the parameter is not present, logs an error and returns a `400 Bad Request` to the client 114 | pub fn param(&self, param: &str) -> Result<&str> { 115 | self.params.find(param).ok_or_else(|| { 116 | error!("parameter {} not found", param); 117 | Error::http(StatusCode::BAD_REQUEST) 118 | }) 119 | } 120 | 121 | /// Get all route parameters 122 | pub fn params(&self) -> &Params { 123 | &self.params 124 | } 125 | 126 | /// Get the request body as a `hyper::Body` 127 | pub async fn body_mut(&mut self) -> Result<&mut Body> { 128 | Ok(self.inner.body_mut()) 129 | } 130 | 131 | pub(crate) fn as_inner_mut(&mut self) -> &mut hyper::Request { 132 | &mut self.inner 133 | } 134 | 135 | /// Get a reader to read the request body 136 | /// 137 | /// (This does buffer the whole body into memory, but not necessarily contiguous memory). 138 | /// If you need to protect against malicious clients you should access the body via `body_mut` 139 | pub async fn reader(&mut self) -> Result { 140 | let buffer = hyper::body::aggregate(self.inner.body_mut()).await?; 141 | Ok(buffer.reader()) 142 | } 143 | 144 | /// Get the request body as raw bytes in a `Vec` 145 | pub async fn body_bytes(&mut self) -> Result> { 146 | let bytes = hyper::body::to_bytes(self.inner.body_mut()).await?; 147 | Ok(bytes.to_vec()) 148 | } 149 | 150 | /// Get the request body as UTF-8 data in String 151 | pub async fn body_string(&mut self) -> Result { 152 | let bytes = hyper::body::to_bytes(self.inner.body_mut()).await?; 153 | Ok(String::from_utf8(bytes.to_vec())?) 154 | } 155 | 156 | /// Get the request body as JSON and deserialize into `T`. 157 | /// 158 | /// If deserialization fails, log an error and return `400 Bad Request`. 
159 | /// (If this logic is not appropriate, consider using `reader` and using `serde_json` directly) 160 | pub async fn body_json(&mut self) -> Result { 161 | let reader = self.reader().await?; 162 | serde_json::from_reader(reader).map_err(|err| { 163 | let msg = format!("error parsing request body as json: {}", err); 164 | error!("{}", msg); 165 | Error::http((StatusCode::BAD_REQUEST, msg)) 166 | }) 167 | } 168 | 169 | /// Get the address of the remote peer. 170 | /// 171 | /// This method uses the network level address only and hence may be incorrect if you are 172 | /// behind a proxy. (This does *not* check for any `Forwarded` headers etc...) 173 | pub fn remote_addr(&self) -> &SocketAddr { 174 | &self.remote_addr 175 | } 176 | } 177 | -------------------------------------------------------------------------------- /src/responder.rs: -------------------------------------------------------------------------------- 1 | use crate::response::Response; 2 | use crate::Result; 3 | use hyper::{Body, StatusCode}; 4 | use serde::Serialize; 5 | 6 | /// This trait is implemented for all the common types you can return from an endpoint 7 | /// 8 | /// It's also implemented for `Response` and `hyper::Response` for compatibility. 
9 | /// There is an implementation for `Result where R: Responder` which allows fallible 10 | /// functions to be used as endpoints 11 | /// 12 | /// ``` 13 | /// use highnoon::{Request, Responder, Json, StatusCode}; 14 | /// 15 | /// fn example_1(_: Request<()>) -> impl Responder { 16 | /// // return status code 17 | /// StatusCode::NOT_FOUND 18 | /// } 19 | /// 20 | /// fn example_2(_: Request<()>) -> impl Responder { 21 | /// // return strings (&str or String) 22 | /// "Hello World" 23 | /// } 24 | /// 25 | /// fn example_3(_: Request<()>) -> impl Responder { 26 | /// // return status code with data 27 | /// (StatusCode::NOT_FOUND, "Not found!") 28 | /// } 29 | /// 30 | /// fn example_4(_: Request<()>) -> impl Responder { 31 | /// // return JSON data - for any type implementing `serde::Serialize` 32 | /// Json(vec![1, 2, 3]) 33 | /// } 34 | /// 35 | /// fn example_5(_: Request<()>) -> highnoon::Result { 36 | /// // fallible functions too 37 | /// // (also works the return type as `impl Responder` as long as Rust can infer 38 | /// // the function returns `highnoon::Result`) 39 | /// Ok((StatusCode::CONFLICT, "Already Exists")) 40 | /// } 41 | /// ``` 42 | 43 | pub trait Responder { 44 | fn into_response(self) -> Result; 45 | } 46 | 47 | impl Responder for StatusCode { 48 | fn into_response(self) -> Result { 49 | Ok(Response::status(self)) 50 | } 51 | } 52 | 53 | impl Responder for String { 54 | fn into_response(self) -> Result { 55 | Ok(Response::ok().body(self)) 56 | } 57 | } 58 | 59 | impl Responder for &str { 60 | fn into_response(self) -> Result { 61 | Ok(Response::ok().body(self.to_owned())) 62 | } 63 | } 64 | 65 | impl Responder for &[u8] { 66 | fn into_response(self) -> Result { 67 | Ok(Response::ok().body(self.to_vec())) 68 | } 69 | } 70 | 71 | impl Responder for Vec { 72 | fn into_response(self) -> Result { 73 | Ok(Response::ok().body(self)) 74 | } 75 | } 76 | 77 | impl Responder for (StatusCode, R) { 78 | fn into_response(self) -> Result { 79 | let mut 
resp = self.1.into_response()?; 80 | resp.set_status(self.0); 81 | Ok(resp) 82 | } 83 | } 84 | 85 | /// Returns `StatusCode::NotFound` for `None`, and the inner value for `Some` 86 | impl Responder for Option { 87 | fn into_response(self) -> Result { 88 | match self { 89 | None => StatusCode::NOT_FOUND.into_response(), 90 | Some(r) => r.into_response(), 91 | } 92 | } 93 | } 94 | 95 | /// A Wrapper to return a JSON payload. This can be wrapped over any `serde::Serialize` type. 96 | /// ``` 97 | /// use highnoon::{Request, Responder, Json}; 98 | /// fn returns_json(_: Request<()>) -> impl Responder { 99 | /// Json(vec!["an", "array"]) 100 | /// } 101 | /// ``` 102 | pub struct Json(pub T); 103 | 104 | impl Responder for Json { 105 | fn into_response(self) -> Result { 106 | Response::ok().json(self.0) 107 | } 108 | } 109 | 110 | /// A Wrapper to return Form data. This can be wrapped over any `serde::Serialize` type. 111 | pub struct Form(pub T); 112 | 113 | impl Responder for Form { 114 | fn into_response(self) -> Result { 115 | Response::ok().form(self.0) 116 | } 117 | } 118 | 119 | /// Identity implementation 120 | impl Responder for Response { 121 | fn into_response(self) -> Result { 122 | Ok(self) 123 | } 124 | } 125 | 126 | /// Compatibility with the inner hyper::Response 127 | impl Responder for hyper::Response { 128 | fn into_response(self) -> Result { 129 | Ok(self.into()) 130 | } 131 | } 132 | 133 | impl Responder for Result { 134 | fn into_response(self) -> Result { 135 | self.and_then(|r| r.into_response()) 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /src/response.rs: -------------------------------------------------------------------------------- 1 | /// A wrapper over `hyper::Response` with better ergonomics 2 | /// 3 | /// ``` 4 | /// use highnoon::{Request, Responder, Response}; 5 | /// fn example(_: Request<()>) -> impl Responder { 6 | /// Response::ok().json(vec![1, 2, 3]) 7 | /// } 8 | /// 
``` 9 | use crate::Result; 10 | use headers::{Header, HeaderMapExt}; 11 | use hyper::header::{HeaderName, HeaderValue}; 12 | use hyper::{Body, StatusCode}; 13 | use serde::Serialize; 14 | use std::convert::TryInto; 15 | use std::path::Path; 16 | use tokio::io::AsyncRead; 17 | use tokio_util::io::ReaderStream; 18 | use tracing::debug; 19 | 20 | /// A response to be returned to the client. 21 | /// You do not always need to use this struct directly as endpoints can 22 | /// return anything implementing `Responder`. However this is the most flexible 23 | /// way to construct a reply, and it implements `Responder` (the "identity" implementation). 24 | #[derive(Debug)] 25 | pub struct Response { 26 | inner: hyper::Response, 27 | } 28 | 29 | impl Response { 30 | /// Create an empty response with status code OK (200) 31 | pub fn ok() -> Self { 32 | Self { 33 | inner: hyper::Response::builder() 34 | .status(StatusCode::OK) 35 | .body(Body::empty()) 36 | .expect("ok status with empty body should never fail"), 37 | } 38 | } 39 | 40 | /// Create an empty response with the given status code 41 | pub fn status(s: StatusCode) -> Self { 42 | Self { 43 | inner: hyper::Response::builder() 44 | .status(s) 45 | .body(Body::empty()) 46 | .expect("status with empty body should never fail"), 47 | } 48 | } 49 | 50 | /// Set the status code of a response 51 | pub fn set_status(&mut self, s: StatusCode) { 52 | *self.inner.status_mut() = s; 53 | } 54 | 55 | /// Get the status code of a response 56 | pub fn get_status(&self) -> StatusCode { 57 | self.inner.status() 58 | } 59 | 60 | /// Set the body of the response 61 | pub fn body(mut self, body: impl Into) -> Self { 62 | *self.inner.body_mut() = body.into(); 63 | self 64 | } 65 | 66 | /// Set the body to an AsyncRead object 67 | pub fn reader(mut self, r: impl AsyncRead + Send + 'static) -> Self { 68 | let body = Body::wrap_stream(ReaderStream::new(r)); 69 | *self.inner.body_mut() = body; 70 | self 71 | } 72 | 73 | /// Set the body to the 
content of a file given by a Path 74 | /// Also sets a content type by guessing the mime type from the path name 75 | pub async fn path(self, path: impl AsRef) -> Result { 76 | let target = path.as_ref(); 77 | 78 | let reader = tokio::fs::File::open(&target).await?; 79 | 80 | let mime = mime_guess::from_path(&target).first_or_text_plain(); 81 | debug!("guessed mime: {}", mime); 82 | 83 | Ok(self.header(headers::ContentType::from(mime)).reader(reader)) 84 | } 85 | 86 | /// Set the body of the response to a JSON payload 87 | pub fn json(mut self, body: impl Serialize) -> Result { 88 | let data = serde_json::to_vec(&body)?; 89 | self.set_header(headers::ContentType::json()); 90 | *self.inner.body_mut() = Body::from(data); 91 | Ok(self) 92 | } 93 | 94 | /// Set the body of the response to form data 95 | pub fn form(mut self, body: impl Serialize) -> Result { 96 | let form = serde_urlencoded::to_string(body)?; 97 | self.set_header(headers::ContentType::form_url_encoded()); 98 | *self.inner.body_mut() = Body::from(form); 99 | Ok(self) 100 | } 101 | 102 | /// Set a header (from the `headers` crate) 103 | pub fn header(mut self, h: H) -> Self { 104 | self.set_header(h); 105 | self 106 | } 107 | 108 | /// Set a header (without consuming self - useful outside of method chains) 109 | pub fn set_header(&mut self, h: H) { 110 | self.inner.headers_mut().typed_insert(h); 111 | } 112 | 113 | /// Set a raw header (from the `http` crate) 114 | pub fn raw_header(mut self, name: N, key: K) -> Result 115 | where 116 | N: TryInto, 117 | K: TryInto, 118 | >::Error: Into, 119 | >::Error: Into, 120 | { 121 | self.set_raw_header(name, key)?; 122 | Ok(self) 123 | } 124 | 125 | /// Set a raw header (without consuming self) 126 | pub fn set_raw_header(&mut self, name: N, key: K) -> Result<()> 127 | where 128 | N: TryInto, 129 | K: TryInto, 130 | >::Error: Into, 131 | >::Error: Into, 132 | { 133 | self.inner 134 | .headers_mut() 135 | .insert(name.try_into()?, key.try_into()?); 136 | Ok(()) 137 
| } 138 | 139 | /// Consume this response and return the inner `hyper::Response` 140 | pub fn into_inner(self) -> hyper::Response { 141 | self.inner 142 | } 143 | } 144 | 145 | /// Create a `Response` from a `hyper::Response` 146 | impl From> for Response { 147 | fn from(hyper_response: hyper::Response) -> Self { 148 | Self { 149 | inner: hyper_response, 150 | } 151 | } 152 | } 153 | 154 | /// Get a reference to the inner `hyper::Response` 155 | impl AsRef> for Response { 156 | fn as_ref(&self) -> &hyper::Response { 157 | &self.inner 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /src/router.rs: -------------------------------------------------------------------------------- 1 | use crate::endpoint::Endpoint; 2 | use crate::state::State; 3 | use crate::{Request, Responder}; 4 | use hyper::{Method, StatusCode}; 5 | use route_recognizer::Params; 6 | use std::collections::HashMap; 7 | 8 | type DynEndpoint = dyn Endpoint + Send + Sync + 'static; 9 | 10 | type Recogniser = route_recognizer::Router>>; 11 | 12 | pub(crate) struct Router { 13 | methods: HashMap>, 14 | all: Recogniser, 15 | } 16 | 17 | pub(crate) struct RouteTarget<'a, S> 18 | where 19 | S: Send + Sync + 'static, 20 | { 21 | pub(crate) ep: &'a DynEndpoint, 22 | pub(crate) params: Params, 23 | } 24 | 25 | impl Router { 26 | pub(crate) fn new() -> Self { 27 | Self { 28 | methods: HashMap::new(), 29 | all: Recogniser::new(), 30 | } 31 | } 32 | 33 | pub(crate) fn add( 34 | &mut self, 35 | method: Method, 36 | path: &str, 37 | ep: impl Endpoint + Sync + Send + 'static, 38 | ) { 39 | self.methods 40 | .entry(method) 41 | .or_insert_with(route_recognizer::Router::new) 42 | .add(path, Box::new(ep)) 43 | } 44 | 45 | pub(crate) fn add_all(&mut self, path: &str, ep: impl Endpoint + Sync + Send + 'static) { 46 | self.all.add(path, Box::new(ep)) 47 | } 48 | 49 | pub(crate) fn lookup(&self, method: &Method, path: &str) -> RouteTarget { 50 | if let Some(match_) = 
self 51 | .methods 52 | .get(method) 53 | .and_then(|recog| recog.recognize(path).ok()) 54 | { 55 | RouteTarget { 56 | ep: &***match_.handler(), 57 | params: match_.params().clone(), // TODO - avoid this clone? 58 | } 59 | } else if let Ok(match_) = self.all.recognize(path) { 60 | RouteTarget { 61 | ep: &***match_.handler(), 62 | params: match_.params().clone(), // TODO - avoid this clone? 63 | } 64 | } else if self 65 | .methods 66 | .iter() 67 | .filter(|(k, _)| k != method) 68 | .any(|(_, recog)| recog.recognize(path).is_ok()) 69 | { 70 | RouteTarget { 71 | ep: &method_not_allowed, 72 | params: Params::new(), 73 | } 74 | } else { 75 | RouteTarget { 76 | ep: ¬_found, 77 | params: Params::new(), 78 | } 79 | } 80 | } 81 | } 82 | 83 | async fn method_not_allowed(_: Request) -> impl Responder { 84 | StatusCode::METHOD_NOT_ALLOWED 85 | } 86 | 87 | async fn not_found(_: Request) -> impl Responder { 88 | StatusCode::NOT_FOUND 89 | } 90 | -------------------------------------------------------------------------------- /src/state.rs: -------------------------------------------------------------------------------- 1 | /// State must be implemented for any type being used as the App's state 2 | /// 3 | /// State is shared by all requests, and must be safe to be shared between 4 | /// threads (Send + Sync + 'static) 5 | /// 6 | /// The state also creates the Context objects used to store request local 7 | /// data. 
8 | /// Before processing a request a new context is created 9 | pub trait State: Send + Sync + 'static { 10 | /// Type of the request local context 11 | type Context: Send + Sync + 'static; 12 | 13 | /// Creating a new Context to be used for a single request 14 | fn new_context(&self) -> Self::Context; 15 | } 16 | 17 | /// Implement state for void 18 | impl State for () { 19 | type Context = (); 20 | 21 | fn new_context(&self) -> Self::Context {} 22 | } 23 | -------------------------------------------------------------------------------- /src/static_files.rs: -------------------------------------------------------------------------------- 1 | use crate::endpoint::Endpoint; 2 | use crate::state::State; 3 | use crate::{Request, Response, Result}; 4 | use async_trait::async_trait; 5 | use hyper::StatusCode; 6 | use std::marker::PhantomData; 7 | use std::path::{Component, PathBuf}; 8 | use tracing::{debug, warn}; 9 | 10 | pub(crate) struct StaticFiles 11 | where 12 | S: Send + Sync + 'static, 13 | { 14 | root: PathBuf, 15 | prefix: PathBuf, 16 | _phantom: PhantomData, 17 | } 18 | 19 | impl StaticFiles 20 | where 21 | S: Send + Sync + 'static, 22 | { 23 | pub(crate) fn new(root: impl Into, prefix: impl Into) -> Self { 24 | let mut prefix = prefix.into(); 25 | // remove the final wildcard path segment 26 | prefix.pop(); 27 | 28 | Self { 29 | root: root.into(), 30 | prefix, 31 | _phantom: PhantomData, 32 | } 33 | } 34 | } 35 | 36 | #[async_trait] 37 | impl Endpoint for StaticFiles { 38 | async fn call(&self, req: Request) -> Result { 39 | let path = PathBuf::from(req.uri().path()); 40 | 41 | let mut target = self.root.clone(); 42 | 43 | for part in path.strip_prefix(&self.prefix)?.components() { 44 | match part { 45 | Component::Normal(component) => { 46 | target.push(component); 47 | } 48 | Component::Prefix(_) => { 49 | // Windows path prefixes - all are forbidden 50 | return Ok(Response::status(StatusCode::FORBIDDEN)); 51 | } 52 | Component::RootDir => { 53 | // 
ignored for URLs 54 | } 55 | Component::CurDir => { 56 | // skip 57 | } 58 | Component::ParentDir => { 59 | target.pop(); 60 | } 61 | } 62 | } 63 | 64 | debug!("path {:?} resolved to file {:?}", path, target); 65 | 66 | if !target.starts_with(&self.root) { 67 | warn!("path tried to navigate out of the static files root dir"); 68 | return Ok(Response::status(StatusCode::FORBIDDEN)); 69 | } 70 | 71 | if !target.is_file() { 72 | // small race condition - if the file is deleted between 73 | // here and where we open it then we're going to return a 500 74 | // instead of 404 75 | warn!("path isn't a file"); 76 | return Ok(Response::status(StatusCode::NOT_FOUND)); 77 | } 78 | 79 | Response::ok().path(target).await 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/test_client.rs: -------------------------------------------------------------------------------- 1 | use crate::test_client::test_request::TestRequest; 2 | use crate::{App, Method, State}; 3 | use hyper::{http, Uri}; 4 | use std::sync::Arc; 5 | 6 | mod test_request; 7 | mod test_response; 8 | 9 | /// A client that can send fake requests to an App and receive the responses back for unit 10 | /// and integration testing. Obtain one by calling [App::test] 11 | pub struct TestClient { 12 | app: Arc>, 13 | } 14 | 15 | impl TestClient { 16 | pub(crate) fn new(app: App) -> Self { 17 | Self { app: Arc::new(app) } 18 | } 19 | 20 | /// Prepare a GET request. Returns a TestRequest which is used to add headers and the body 21 | /// before being sent. 22 | pub fn get(&self, uri: U) -> TestRequest 23 | where 24 | Uri: TryFrom, 25 | >::Error: Into, 26 | { 27 | self.method(Method::GET, uri) 28 | } 29 | 30 | /// Prepare a POST request. Returns a TestRequest which is used to add headers and the body 31 | /// before being sent. 
32 | pub fn post(&self, uri: U) -> TestRequest 33 | where 34 | Uri: TryFrom, 35 | >::Error: Into, 36 | { 37 | self.method(Method::POST, uri) 38 | } 39 | 40 | /// Prepare a PUT request. Returns a TestRequest which is used to add headers and the body 41 | /// before being sent. 42 | pub fn put(&self, uri: U) -> TestRequest 43 | where 44 | Uri: TryFrom, 45 | >::Error: Into, 46 | { 47 | self.method(Method::PUT, uri) 48 | } 49 | 50 | /// Prepare a DELETE request. Returns a TestRequest which is used to add headers and the body 51 | /// before being sent. 52 | pub fn delete(&self, uri: U) -> TestRequest 53 | where 54 | Uri: TryFrom, 55 | >::Error: Into, 56 | { 57 | self.method(Method::DELETE, uri) 58 | } 59 | 60 | /// Prepare request with the given HTTP method. Returns a TestRequest which is used to add headers 61 | /// and the body before being sent. 62 | pub fn method(&self, method: Method, uri: U) -> TestRequest 63 | where 64 | Uri: TryFrom, 65 | >::Error: Into, 66 | { 67 | TestRequest::new( 68 | self.app.clone(), 69 | http::request::Builder::new().method(method).uri(uri), 70 | ) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /src/test_client/test_request.rs: -------------------------------------------------------------------------------- 1 | use crate::Result; 2 | use crate::{App, State}; 3 | use headers::{Header, HeaderMapExt}; 4 | use hyper::header::{HeaderName, HeaderValue}; 5 | use hyper::{http, Body, HeaderMap}; 6 | use serde::Serialize; 7 | use std::sync::Arc; 8 | //use crate::test_client::into_body::IntoBody; 9 | use crate::test_client::test_response::TestResponse; 10 | 11 | enum PartialReq { 12 | Builder(http::request::Builder), 13 | Request(hyper::Request), 14 | } 15 | 16 | /// A fake request used for testing an App. Obtain one by calling the relevant methods on 17 | /// the [TestClient] (eg. [TestClient::get], [TestClient::post]...) 
18 | /// After optionally adding headers and a body you can send the request to receive the response 19 | /// from the App. 20 | pub struct TestRequest { 21 | app: Arc>, 22 | req: PartialReq, 23 | } 24 | 25 | impl TestRequest { 26 | pub(crate) fn new(app: Arc>, builder: http::request::Builder) -> Self { 27 | Self { 28 | app, 29 | req: PartialReq::Builder(builder), 30 | } 31 | } 32 | /// Mutable access to the request headers, whether or not a body has been attached yet. 33 | fn headers_mut(&mut self) -> &mut HeaderMap { 34 | match &mut self.req { 35 | PartialReq::Builder(b) => b.headers_mut().expect("error getting headers"), 36 | PartialReq::Request(req) => req.headers_mut(), 37 | } 38 | } 39 | 40 | /// Set a header (from the `headers` crate) 41 | pub fn header(mut self, h: H) -> Self { 42 | self.headers_mut().typed_insert(h); 43 | self 44 | } 45 | 46 | /// Set a raw header (from the `http` crate) 47 | pub fn raw_header(mut self, name: N, key: K) -> Result 48 | where 49 | N: TryInto, 50 | K: TryInto, 51 | >::Error: Into, 52 | >::Error: Into, 53 | { 54 | self.headers_mut().insert(name.try_into()?, key.try_into()?); 55 | Ok(self) 56 | } 57 | 58 | /// Add a body to this request. Panics if a body has already been set. 59 | pub fn body(mut self, body: impl Into) -> Result { 60 | self.req = match self.req { 61 | PartialReq::Builder(b) => PartialReq::Request(b.body(body.into())?), 62 | PartialReq::Request(_req) => { 63 | panic!("body already set!") 64 | } 65 | }; 66 | Ok(self) 67 | } 68 | 69 | /// Add a JSON encoded body to this request, and set the `Content-Type` header 70 | /// to `application/json` 71 | pub fn json(self, data: impl Serialize) -> Result { 72 | let body = serde_json::to_string(&data)?; 73 | self.header(headers::ContentType::json()).body(body) 74 | } 75 | 76 | /// Send the request to the App and receive the response; a fixed address (127.0.0.1:8080) is passed to the App as the connection address.
77 | pub async fn send(self) -> Result { 78 | let req = match self.req { 79 | PartialReq::Builder(b) => b.body(Body::empty())?, 80 | PartialReq::Request(r) => r, 81 | }; 82 | 83 | let addr = "127.0.0.1:8080".parse().expect("socket addr is invalid?"); 84 | let resp = App::serve_one_req(self.app, req, addr).await?; 85 | Ok(TestResponse::from(resp)) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/test_client/test_response.rs: -------------------------------------------------------------------------------- 1 | use crate::{Result, StatusCode}; 2 | use hyper::{body::Buf, Body, Response}; 3 | use serde::de::DeserializeOwned; 4 | 5 | /// The response returned from the test client. 6 | /// This currently has an AsRef implementation to get the inner hyper response 7 | /// but more helper methods will be added over time to reduce the need for touching 8 | /// the raw hyper types. 9 | pub struct TestResponse { 10 | inner: hyper::Response, 11 | } 12 | 13 | impl From> for TestResponse { 14 | fn from(resp: Response) -> Self { 15 | Self { inner: resp } 16 | } 17 | } 18 | 19 | impl TestResponse { 20 | /// Get the status code 21 | pub fn status(&self) -> StatusCode { 22 | self.inner.status() 23 | } 24 | 25 | /// Get the response body as UTF-8 data in a String 26 | pub async fn body_string(&mut self) -> Result { 27 | let bytes = hyper::body::to_bytes(self.inner.body_mut()).await?; 28 | Ok(String::from_utf8(bytes.to_vec())?) 29 | } 30 | 31 | /// Get the response body as bytes in a Vec 32 | pub async fn body_bytes(&mut self) -> Result> { 33 | let bytes = hyper::body::to_bytes(self.inner.body_mut()).await?; 34 | Ok(bytes.to_vec()) 35 | } 36 | 37 | /// Get the response body by decoding JSON. Any type that implements Deserialize can be used.
38 | pub async fn body_json(&mut self) -> Result { 39 | let buffer = hyper::body::aggregate(self.inner.body_mut()).await?; 40 | let data = serde_json::from_reader(buffer.reader())?; 41 | Ok(data) 42 | } 43 | } 44 | 45 | impl AsRef> for TestResponse { 46 | fn as_ref(&self) -> &hyper::Response { 47 | &self.inner 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/ws.rs: -------------------------------------------------------------------------------- 1 | use crate::endpoint::Endpoint; 2 | use crate::state::State; 3 | use crate::{Request, Response, Result}; 4 | use async_trait::async_trait; 5 | use futures_util::stream::{SplitSink, SplitStream}; 6 | use futures_util::{SinkExt, StreamExt, TryStreamExt}; 7 | use hyper::upgrade::Upgraded; 8 | use hyper::StatusCode; 9 | use std::future::Future; 10 | use std::marker::PhantomData; 11 | use std::sync::Arc; 12 | use tokio_tungstenite::tungstenite::Message; 13 | use tokio_tungstenite::WebSocketStream; 14 | use tracing::trace; 15 | 16 | /// An endpoint for accepting a websocket connection. 17 | /// Typically constructed by the `Route::ws` method. 18 | #[derive(Debug)] 19 | pub struct WsEndpoint 20 | where 21 | S: State + Send + Sync + 'static, 22 | H: Send + Sync + 'static + Fn(Request, WebSocketSender, WebSocketReceiver) -> F, 23 | F: Future> + Send + 'static, 24 | { 25 | handler: Arc, 26 | _phantoms: PhantomData, 27 | } 28 | 29 | /// Create a websocket endpoint. 30 | /// Typically called by the `Route::ws` method.
31 | pub fn endpoint(handler: H) -> WsEndpoint 32 | where 33 | S: State + Send + Sync + 'static, 34 | H: Send + Sync + 'static + Fn(Request, WebSocketSender, WebSocketReceiver) -> F, 35 | F: Future> + Send + 'static, 36 | { 37 | WsEndpoint { 38 | handler: Arc::new(handler), 39 | _phantoms: PhantomData, 40 | } 41 | } 42 | 43 | #[async_trait] 44 | impl Endpoint for WsEndpoint 45 | where 46 | S: State, 47 | H: Send + Sync + 'static + Fn(Request, WebSocketSender, WebSocketReceiver) -> F, 48 | F: Future> + Send + 'static, 49 | { 50 | async fn call(&self, req: Request) -> Result { 51 | let handler = self.handler.clone(); 52 | 53 | let res = upgrade_connection(req, handler).await; 54 | 55 | Ok(res) 56 | } 57 | } 58 | 59 | async fn upgrade_connection(mut req: Request, handler: Arc) -> Response 60 | where 61 | S: State, 62 | H: Send + Sync + 'static + Fn(Request, WebSocketSender, WebSocketReceiver) -> F, 63 | F: Future> + Send + 'static, 64 | { 65 | // TODO - check various headers 66 | 67 | if let Some(conn) = req.header::() { 68 | if !conn.contains(hyper::header::UPGRADE) { 69 | return Response::status(StatusCode::BAD_REQUEST); 70 | } 71 | } else { 72 | return Response::status(StatusCode::BAD_REQUEST); 73 | } 74 | 75 | if let Some(upgrade) = req.header::() { 76 | if upgrade != headers::Upgrade::websocket() { 77 | return Response::status(StatusCode::BAD_REQUEST); 78 | } 79 | } else { 80 | return Response::status(StatusCode::BAD_REQUEST); 81 | } 82 | 83 | let key = match req.header::() { 84 | Some(header) => header, 85 | None => return Response::status(StatusCode::BAD_REQUEST), 86 | }; 87 | 88 | let res = Response::status(StatusCode::SWITCHING_PROTOCOLS) 89 | .header(headers::Upgrade::websocket()) 90 | .header(headers::Connection::upgrade()) 91 | .header(headers::SecWebsocketAccept::from(key)); 92 | 93 | trace!("upgrading connection to websocket"); 94 | 95 | tokio::spawn(async move { 96 | let upgraded = hyper::upgrade::on(req.as_inner_mut()) 97 | .await 98 | 
.expect("websocket upgrade failed - TODO report this error"); 99 | 100 | let ws = WebSocketStream::from_raw_socket( 101 | upgraded, 102 | tokio_tungstenite::tungstenite::protocol::Role::Server, 103 | None, 104 | ) 105 | .await; 106 | 107 | let (tx, rx) = ws.split(); 108 | let res = (handler)( 109 | req, 110 | WebSocketSender { inner: tx }, 111 | WebSocketReceiver { inner: rx }, 112 | ) 113 | .await; 114 | 115 | match res { 116 | Ok(()) => trace!("websocket handler returned"), 117 | Err(e) => trace!("websocket handler returned an error: {}", e), 118 | }; 119 | }); 120 | 121 | res 122 | } 123 | 124 | /// The sending half of the websocket connection 125 | pub struct WebSocketSender { 126 | inner: SplitSink, Message>, 127 | } 128 | 129 | impl WebSocketSender { 130 | /// Send a message over the websocket 131 | pub async fn send(&mut self, msg: Message) -> Result<()> { 132 | self.inner.send(msg).await?; 133 | Ok(()) 134 | } 135 | } 136 | 137 | /// The receiving half of the websocket connection 138 | pub struct WebSocketReceiver { 139 | inner: SplitStream>, 140 | } 141 | 142 | impl WebSocketReceiver { 143 | /// Receive a message from the websocket 144 | pub async fn recv(&mut self) -> Result> { 145 | let msg = self.inner.try_next().await?; 146 | Ok(msg) 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /tests/test_client.rs: -------------------------------------------------------------------------------- 1 | use highnoon::{App, Json, Request, StatusCode}; 2 | use serde_json::{json, Value}; 3 | 4 | fn make_app() -> App<()> { 5 | let mut app = App::new(()); 6 | 7 | app.at("/greeting").get(|_req| async { "Hello World!" 
}); 8 | 9 | app.at("/reverse").get(|mut req: Request<()>| async move { 10 | let mut data = req.body_bytes().await?; 11 | data.reverse(); 12 | Ok(data) 13 | }); 14 | 15 | app.at("/json").get(|mut req: Request<()>| async move { 16 | let data: Value = req.body_json().await?; 17 | let greeting = data 18 | .get("greeting") 19 | .and_then(|val| val.as_str()) 20 | .map(|s| s.to_owned()); 21 | Ok(Json(greeting)) 22 | }); 23 | 24 | app 25 | } 26 | 27 | #[tokio::main] 28 | #[test] 29 | pub async fn test_greeting() -> highnoon::Result<()> { 30 | let tc = make_app().test(); 31 | 32 | let mut resp = tc.get("/greeting").send().await?; 33 | assert_eq!(resp.status(), StatusCode::OK); 34 | assert_eq!(resp.body_string().await?, "Hello World!"); 35 | 36 | Ok(()) 37 | } 38 | 39 | #[tokio::main] 40 | #[test] 41 | pub async fn test_reverse() -> highnoon::Result<()> { 42 | let tc = make_app().test(); 43 | 44 | let mut resp = tc.get("/reverse").body("Hello World!")?.send().await?; 45 | 46 | assert_eq!(resp.status(), StatusCode::OK); 47 | assert_eq!(resp.body_string().await?, "!dlroW olleH"); 48 | 49 | Ok(()) 50 | } 51 | 52 | #[tokio::main] 53 | #[test] 54 | pub async fn test_json() -> highnoon::Result<()> { 55 | let tc = make_app().test(); 56 | 57 | let mut resp = tc 58 | .get("/json") 59 | .json(json!({ 60 | "greeting": "Hello World!" 61 | }))? 
62 | .send() 63 | .await?; 64 | 65 | assert_eq!(resp.status(), StatusCode::OK); 66 | assert_eq!(resp.body_string().await?, "\"Hello World!\""); 67 | 68 | Ok(()) 69 | } 70 | 71 | #[tokio::main] 72 | #[test] 73 | pub async fn test_404() -> highnoon::Result<()> { 74 | let tc = make_app().test(); 75 | 76 | let resp = tc.get("/no_such_route").send().await?; 77 | 78 | assert_eq!(resp.status(), StatusCode::NOT_FOUND); 79 | 80 | Ok(()) 81 | } 82 | 83 | #[tokio::main] 84 | #[test] 85 | pub async fn test_method_not_allowed() -> highnoon::Result<()> { 86 | let tc = make_app().test(); 87 | 88 | let resp = tc.delete("/greeting").send().await?; 89 | 90 | assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); 91 | 92 | Ok(()) 93 | } 94 | --------------------------------------------------------------------------------