├── .github ├── media │ └── logo.png └── workflows │ └── rust.yml ├── .gitignore ├── .travis.yml ├── CODE_OF_CONDUCT.md ├── Cargo.toml ├── LICENSE ├── README.md ├── examples ├── app_state.rs ├── hello.rs ├── hello_handler.rs ├── json.rs ├── main.rs └── middleware │ ├── logger_example.rs │ └── mod.rs ├── screenshot └── serve.png └── src ├── app.rs ├── context.rs ├── error.rs ├── error └── obsidian_error.rs ├── lib.rs ├── middleware.rs ├── middleware └── logger.rs ├── router.rs └── router ├── handler.rs ├── req_deserializer.rs ├── resource.rs ├── responder.rs ├── response.rs ├── response_body.rs ├── route.rs └── route_trie.rs /.github/media/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obsidian-rs/obsidian/bbc25ea05c97cfe621e3000763e78dd4f9cb1fdb/.github/media/logo.png -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Obsidian Action 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - develop 8 | - release/* 9 | pull_request: 10 | branches: 11 | - master 12 | - develop 13 | 14 | jobs: 15 | build_stable: 16 | name: Stable 17 | runs-on: ${{ matrix.os }} 18 | strategy: 19 | matrix: 20 | os: [ubuntu-latest, windows-latest, macOS-latest] 21 | rust: [stable] 22 | 23 | steps: 24 | - uses: hecrj/setup-rust-action@v1 25 | with: 26 | rust-version: ${{ matrix.rust }} 27 | components: rustfmt 28 | - uses: actions/checkout@v2 29 | - name: Build 30 | run: cargo build --verbose 31 | - name: Install clippy 32 | run: rustup component add clippy 33 | - name: Check code format 34 | run: cargo fmt --all -- --check 35 | - name: Run clippy 36 | run: cargo clippy --all-targets --all-features -- -D warnings 37 | - name: Run tests 38 | run: cargo test --verbose 39 | 40 | build_nightly: 41 | name: Nightly 42 | runs-on: ${{ matrix.os }} 43 | strategy: 44 | matrix: 
45 | os: [ubuntu-latest, windows-latest, macOS-latest] 46 | rust: [nightly] 47 | 48 | steps: 49 | - uses: hecrj/setup-rust-action@v1 50 | with: 51 | rust-version: ${{ matrix.rust }} 52 | - uses: actions/checkout@v2 53 | - name: Build 54 | run: cargo +nightly build --verbose 55 | # - name: Install clippy 56 | # run: rustup component add clippy 57 | # - name: Run clippy 58 | # run: cargo clippy --all-targets --all-features -- -D warnings 59 | - name: Run tests 60 | run: cargo +nightly test --verbose 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | /target/ 4 | 5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 6 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 7 | Cargo.lock 8 | 9 | # These are backup files generated by rustfmt 10 | **/*.rs.bk 11 | 12 | .DS_Store 13 | 14 | # vscode local config 15 | /.vscode/ -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: rust 2 | rust: 3 | - stable 4 | - beta 5 | - nightly 6 | branches: 7 | only: 8 | - master 9 | - develop 10 | matrix: 11 | allow_failures: 12 | - rust: nightly 13 | fast_finish: true 14 | cache: cargo 15 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and 
expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 
50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | . 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. 
Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 
129 | 130 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = 'obsidian' 3 | version = '0.3.0-dev' 4 | authors = [ 5 | 'Gan Jun Kai ', 6 | 'Wai Pai Lee ', 7 | ] 8 | edition = '2018' 9 | description = 'Ergonomic async http framework for amazing, reliable and efficient web' 10 | readme = "README.md" 11 | homepage = "https://obsidian-rs.github.io" 12 | repository = 'https://github.com/obsidian-rs/obsidian' 13 | license = 'MIT' 14 | keywords = [ 15 | 'obsidian', 16 | 'async', 17 | 'http', 18 | 'web', 19 | 'framework' 20 | ] 21 | categories = ["asynchronous", "web-programming::http-server", "network-programming"] 22 | 23 | [[example]] 24 | name = 'example' 25 | path = 'examples/main.rs' 26 | 27 | [[example]] 28 | name = 'hello' 29 | path = 'examples/hello.rs' 30 | 31 | [[example]] 32 | name = 'hello_handler' 33 | path = 'examples/hello_handler.rs' 34 | 35 | [[example]] 36 | name = 'json' 37 | path = 'examples/json.rs' 38 | 39 | [[example]] 40 | name = 'app_state' 41 | path = 'examples/app_state.rs' 42 | 43 | [dependencies] 44 | hyper = { version = "0.14.9", features = [ "full" ] } 45 | http = "0.2.4" 46 | serde = { version = "1.0.126", features = [ "derive" ] } 47 | serde_json = "1.0.64" 48 | url = "2.2.2" 49 | async-std = "1.9.0" 50 | tokio = { version = "1.7.0", features = [ "macros", "rt-multi-thread" ] } 51 | async-trait = "0.1.50" 52 | colored = "2.0.0" 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Gan Jun Kai & Wai Pai Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without 
limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | Obsidian Logo 4 | 5 |

6 | Obsidian 7 |

8 |

9 | 10 |

Obsidian is an ergonomic Rust async http framework for reliable and efficient web.

11 | 12 |
13 | 14 | Obsidian crate 15 | 16 | 17 | GitHub Actions status 18 | 19 |
20 | 21 |
22 | Obsidian serve 23 |
24 | 25 | ## Get Started 26 | ```toml 27 | [dependencies] 28 | # add these 2 dependencies in Cargo.toml file 29 | obsidian = "0.2.2" 30 | tokio = "0.2.21" 31 | ``` 32 | 33 | ## Hello World 34 | 35 | ```rust 36 | use obsidian::{context::Context, App}; 37 | 38 | #[tokio::main] 39 | async fn main() { 40 | let mut app: App = App::new(); 41 | 42 | app.get("/", |ctx: Context| async { ctx.build("Hello World").ok() }); 43 | 44 | app.listen(3000).await; 45 | } 46 | ``` 47 | 48 | ## Hello World (with handler function) 49 | 50 | ```rust 51 | use obsidian::{context::Context, App, ContextResult}; 52 | 53 | async fn hello_world(ctx: Context) -> ContextResult { 54 | ctx.build("Hello World").ok() 55 | } 56 | 57 | 58 | #[tokio::main] 59 | async fn main() { 60 | let mut app: App = App::new(); 61 | 62 | app.get("/", hello_world); 63 | 64 | app.listen(3000).await; 65 | } 66 | ``` 67 | 68 | ## JSON Response 69 | 70 | ```rust 71 | use obsidian::{context::Context, App, ContextResult}; 72 | use serde::*; 73 | 74 | async fn get_user(ctx: Context) -> ContextResult { 75 | #[derive(Serialize, Deserialize)] 76 | struct User { 77 | name: String, 78 | }; 79 | 80 | let user = User { 81 | name: String::from("Obsidian"), 82 | }; 83 | ctx.build_json(user).ok() 84 | } 85 | 86 | #[tokio::main] 87 | async fn main() { 88 | let mut app: App = App::new(); 89 | 90 | app.get("/user", get_user); 91 | 92 | app.listen(3000).await; 93 | } 94 | 95 | ``` 96 | 97 | ## Example Files 98 | 99 | Examples are located in the `examples/` directory; the main one is `examples/main.rs`. 100 | 101 | ## Run Example 102 | 103 | ``` 104 | cargo run --example example 105 | ``` 106 | 107 | ## Current State 108 | 109 | NOT READY FOR PRODUCTION YET! 
110 | -------------------------------------------------------------------------------- /examples/app_state.rs: -------------------------------------------------------------------------------- 1 | use obsidian::{context::Context, App, ObsidianError}; 2 | 3 | #[derive(Clone)] 4 | pub struct AppState { 5 | pub db_connection_string: String, 6 | } 7 | 8 | #[tokio::main] 9 | async fn main() { 10 | let mut app: App = App::new(); 11 | 12 | app.set_app_state(AppState { 13 | db_connection_string: "localhost:1433".to_string(), 14 | }); 15 | 16 | app.get("/", |ctx: Context| async { 17 | let app_state = ctx.get::().ok_or(ObsidianError::NoneError)?; 18 | let res = Some(format!( 19 | "connection string: {}", 20 | &app_state.db_connection_string 21 | )); 22 | 23 | ctx.build(res).ok() 24 | }); 25 | 26 | app.listen(3000).await; 27 | } 28 | -------------------------------------------------------------------------------- /examples/hello.rs: -------------------------------------------------------------------------------- 1 | use obsidian::{context::Context, App}; 2 | 3 | #[tokio::main] 4 | async fn main() { 5 | let mut app: App = App::new(); 6 | 7 | app.get("/", |ctx: Context| async { ctx.build("Hello World").ok() }); 8 | 9 | app.listen(3000).await; 10 | } 11 | -------------------------------------------------------------------------------- /examples/hello_handler.rs: -------------------------------------------------------------------------------- 1 | use obsidian::{context::Context, App, ContextResult}; 2 | 3 | async fn hello_world(ctx: Context) -> ContextResult { 4 | ctx.build("Hello World").ok() 5 | } 6 | 7 | #[tokio::main] 8 | async fn main() { 9 | let mut app: App = App::new(); 10 | 11 | app.get("/", hello_world); 12 | 13 | app.listen(3000).await; 14 | } 15 | -------------------------------------------------------------------------------- /examples/json.rs: -------------------------------------------------------------------------------- 1 | use obsidian::{context::Context, App, 
ContextResult}; 2 | use serde::*; 3 | 4 | async fn get_user(ctx: Context) -> ContextResult { 5 | #[derive(Serialize, Deserialize)] 6 | struct User { 7 | name: String, 8 | } 9 | 10 | let user = User { 11 | name: String::from("Obsidian"), 12 | }; 13 | ctx.build_json(user).ok() 14 | } 15 | 16 | #[tokio::main] 17 | async fn main() { 18 | let mut app: App = App::new(); 19 | 20 | app.get("/user", get_user); 21 | 22 | app.listen(3000).await; 23 | } 24 | -------------------------------------------------------------------------------- /examples/main.rs: -------------------------------------------------------------------------------- 1 | mod middleware; 2 | 3 | use middleware::logger_example::*; 4 | use serde::*; 5 | use std::{fmt, fmt::Display}; 6 | 7 | use obsidian::{ 8 | context::Context, 9 | router::{header, Responder, Response, Router}, 10 | App, ObsidianError, StatusCode, 11 | }; 12 | 13 | // Testing example 14 | #[derive(Serialize, Deserialize, Debug)] 15 | struct Point { 16 | x: i32, 17 | y: i32, 18 | } 19 | 20 | #[derive(Serialize, Deserialize, Debug)] 21 | struct JsonTest { 22 | title: String, 23 | content: String, 24 | } 25 | 26 | #[derive(Serialize, Deserialize, Debug)] 27 | struct ParamTest { 28 | test: Vec, 29 | test2: String, 30 | } 31 | 32 | #[derive(Serialize, Deserialize, Debug)] 33 | struct User { 34 | name: String, 35 | age: i8, 36 | } 37 | 38 | impl Display for JsonTest { 39 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 40 | write!(f, "{{title: {}, content: {}}}", self.title, self.content) 41 | } 42 | } 43 | 44 | // async fn responder_json(mut ctx: Context) -> impl Responder { 45 | // let person: Person = ctx.json().await?; 46 | 47 | // person.age += 1; 48 | 49 | // Ok(response::json(person)) 50 | // } 51 | 52 | // async fn responder_obsidian_error(mut ctx: Context) -> impl Responder { 53 | // let json: JsonTest = ctx.json().await?; 54 | // println!("{}", json); 55 | // Ok(response::json(json, StatusCode::OK)) 56 | // } 57 | 58 | // fn 
responder_with_header(ctx: Context: Context) -> impl Responder { 59 | // let headers = vec![ 60 | // ("X-Custom-Header-4", "custom-value-4"), 61 | // ("X-Custom-Header-5", "custom-value-5"), 62 | // ]; 63 | 64 | // "here" 65 | // .header("Content-Type", "application/json") 66 | // .header("X-Custom-Header", "custom-value") 67 | // .header("X-Custom-Header-2", "custom-value-2") 68 | // .header("X-Custom-Header-3", "custom-value-3") 69 | // .set_headers(headers) 70 | // .status(StatusCode::CREATED) 71 | // } 72 | 73 | #[tokio::main] 74 | async fn main() { 75 | let mut app: App = App::new(); 76 | 77 | app.get("/", |ctx: Context| async { 78 | ctx.build(Response::ok().html("

Hello Obsidian

")).ok() 79 | }); 80 | 81 | app.get("/json", |ctx: Context| async { 82 | let point = Point { x: 1, y: 2 }; 83 | 84 | ctx.build_json(point) 85 | .with_status(StatusCode::OK) 86 | .with_header(header::AUTHORIZATION, "token") 87 | .with_header_str("X-Custom-Header", "Custom header value") 88 | .ok() 89 | }); 90 | 91 | app.get("/user", |mut ctx: Context| async { 92 | #[derive(Serialize, Deserialize, Debug)] 93 | struct QueryString { 94 | id: String, 95 | status: String, 96 | } 97 | 98 | let params = match ctx.query_params::() { 99 | Ok(params) => params, 100 | Err(error) => { 101 | println!("error: {}", error); 102 | QueryString { 103 | id: String::from(""), 104 | status: String::from(""), 105 | } 106 | } 107 | }; 108 | 109 | println!("params: {:?}", params); 110 | 111 | ctx.build("").ok() 112 | }); 113 | 114 | app.patch("/patch-here", |ctx: Context| async { 115 | ctx.build("Here is patch request").ok() 116 | }); 117 | 118 | app.get("/json-with-headers", |ctx: Context| async { 119 | let point = Point { x: 1, y: 2 }; 120 | 121 | let custom_headers = vec![ 122 | ("X-Custom-Header-1", "Custom header 1"), 123 | ("X-Custom-Header-2", "Custom header 2"), 124 | ("X-Custom-Header-3", "Custom header 3"), 125 | ]; 126 | 127 | let standard_headers = vec![ 128 | (header::AUTHORIZATION, "token"), 129 | (header::ACCEPT_CHARSET, "utf-8"), 130 | ]; 131 | 132 | ctx.build( 133 | Response::created() 134 | .with_headers(standard_headers) 135 | .with_headers_str(custom_headers) 136 | .json(point), 137 | ) 138 | .ok() 139 | }); 140 | 141 | app.get("/string-with-headers", |ctx: Context| async { 142 | let custom_headers = vec![ 143 | ("X-Custom-Header-1", "Custom header 1"), 144 | ("X-Custom-Header-2", "Custom header 2"), 145 | ("X-Custom-Header-3", "Custom header 3"), 146 | ]; 147 | 148 | let standard_headers = vec![ 149 | (header::AUTHORIZATION, "token"), 150 | (header::ACCEPT_CHARSET, "utf-8"), 151 | ]; 152 | 153 | ctx.build("Hello World") 154 | .with_headers(standard_headers) 155 | 
.with_headers_str(custom_headers) 156 | .ok() 157 | }); 158 | 159 | app.get("/empty-body", |ctx: Context| async { 160 | ctx.build(StatusCode::OK).ok() 161 | }); 162 | 163 | app.get("/vec", |ctx: Context| async { 164 | ctx.build(vec![1, 2, 3]) 165 | .with_status(StatusCode::CREATED) 166 | .ok() 167 | }); 168 | 169 | app.get("/String", |ctx: Context| async { 170 | ctx.build("

This is a String

".to_string()).ok() 171 | }); 172 | 173 | app.get("/test/radix", |ctx: Context| async { 174 | ctx.build("

Test radix

".to_string()).ok() 175 | }); 176 | 177 | app.get("/team/radix", |ctx: Context| async { 178 | ctx.build("Team radix".to_string()).ok() 179 | }); 180 | 181 | app.get("/test/radix2", |ctx: Context| async { 182 | ctx.build("

Test radix2

".to_string()).ok() 183 | }); 184 | 185 | app.get("/jsontest", |ctx: Context| async { 186 | ctx.build_file("./testjson.html").await.ok() 187 | }); 188 | 189 | app.get("/jsan", |ctx: Context| async { 190 | ctx.build("

jsan

".to_string()).ok() 191 | }); 192 | 193 | app.get("/test/wildcard/*", |ctx: Context| async move { 194 | let res = format!( 195 | "{}
{}", 196 | "

Test wildcard

".to_string(), 197 | ctx.uri().path() 198 | ); 199 | 200 | ctx.build(res).ok() 201 | }); 202 | 203 | app.get("router/test", |ctx: Context| async move { 204 | let result = ctx 205 | .extensions() 206 | .get::() 207 | .ok_or(ObsidianError::NoneError)?; 208 | 209 | dbg!(&result.0); 210 | 211 | let res = Some(format!( 212 | "{}
{}", 213 | "

router test get

".to_string(), 214 | ctx.uri().path() 215 | )); 216 | 217 | ctx.build(res).ok() 218 | }); 219 | app.post("router/test", |ctx: Context| async move { 220 | let res = format!( 221 | "{}
{}", 222 | "

router test post

".to_string(), 223 | ctx.uri().path() 224 | ); 225 | 226 | ctx.build(res).ok() 227 | }); 228 | app.put("router/test", |ctx: Context| async move { 229 | let res = format!( 230 | "{}
{}", 231 | "

router test put

".to_string(), 232 | ctx.uri().path() 233 | ); 234 | 235 | ctx.build(res).ok() 236 | }); 237 | app.delete("router/test", |ctx: Context| async move { 238 | let res = format!( 239 | "{}
{}", 240 | "

router test delete

".to_string(), 241 | ctx.uri().path() 242 | ); 243 | 244 | ctx.build(res).ok() 245 | }); 246 | 247 | app.get("route/diff_route", |ctx: Context| async move { 248 | let res = format!( 249 | "{}
{}", 250 | "

route diff get

".to_string(), 251 | ctx.uri().path() 252 | ); 253 | 254 | ctx.build(res).ok() 255 | }); 256 | 257 | app.scope("admin", |router: &mut Router| { 258 | router.get("test", |ctx: Context| async move { 259 | ctx.build("Hello admin test").ok() 260 | }); 261 | 262 | router.get("test2", |ctx: Context| async move { 263 | ctx.build("Hello admin test 2").ok() 264 | }); 265 | }); 266 | 267 | app.scope("form", |router: &mut Router| { 268 | router.get("/formtest", |ctx: Context| async move { 269 | ctx.build_file("/.test.html").await.ok() 270 | }); 271 | }); 272 | 273 | // form_router.post("/formtest", |mut ctx: Context| async move{ 274 | // let param_test: ParamTest = ctx.form().await?; 275 | 276 | // dbg!(¶m_test); 277 | 278 | // Ok(response::json(param_test, StatusCode::OK)) 279 | // }); 280 | 281 | // param_router.get("/paramtest/:id", |ctx: Context| async move { 282 | // let param_test: i32 = ctx.param("id")?; 283 | 284 | // dbg!(¶m_test); 285 | 286 | // Ok(response::json(param_test, StatusCode::OK)) 287 | // }); 288 | // 289 | // param_router.get("/paramtest/:id/test", |ctx: Context| async move { 290 | // let mut param_test: i32 = ctx.param("id").unwrap(); 291 | // param_test = param_test * 10; 292 | 293 | // dbg!(¶m_test); 294 | 295 | // Ok(response::json(param_test, StatusCode::OK)) 296 | // }); 297 | 298 | let logger_example = middleware::logger_example::LoggerExample::new(); 299 | app.use_service(logger_example); 300 | 301 | app.scope("params", |router: &mut Router| { 302 | router.get("/test-next-wild/*", |ctx: Context| async { 303 | ctx.build("

test next wild

".to_string()).ok() 304 | }); 305 | 306 | router.get("/*", |ctx: Context| async { 307 | ctx.build( 308 | "

404 Not Found

" 309 | .to_string() 310 | .with_status(StatusCode::NOT_FOUND), 311 | ) 312 | .ok() 313 | }); 314 | }); 315 | 316 | app.use_static_to("/files/", "/assets/"); 317 | 318 | app.listen(3000).await; 319 | } 320 | -------------------------------------------------------------------------------- /examples/middleware/logger_example.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | 3 | #[cfg(debug_assertions)] 4 | use colored::*; 5 | 6 | use obsidian::{context::Context, middleware::Middleware, ContextResult, EndpointExecutor}; 7 | 8 | #[derive(Default)] 9 | pub struct LoggerExample {} 10 | 11 | pub struct LoggerExampleData(pub String); 12 | 13 | impl LoggerExample { 14 | #[allow(dead_code)] 15 | pub fn new() -> Self { 16 | LoggerExample {} 17 | } 18 | } 19 | 20 | #[async_trait] 21 | impl Middleware for LoggerExample { 22 | async fn handle<'a>( 23 | &'a self, 24 | mut context: Context, 25 | ep_executor: EndpointExecutor<'a>, 26 | ) -> ContextResult { 27 | #[cfg(debug_assertions)] 28 | println!("{}", "[debug] Inside middleware".blue()); 29 | 30 | println!( 31 | "{} {}{}", 32 | context.method(), 33 | context.headers().get("host").unwrap().to_str().unwrap(), 34 | context.uri(), 35 | ); 36 | 37 | context.add(LoggerExampleData("This is logger data".to_string())); 38 | 39 | ep_executor.next(context).await 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /examples/middleware/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod logger_example; 2 | -------------------------------------------------------------------------------- /screenshot/serve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/obsidian-rs/obsidian/bbc25ea05c97cfe621e3000763e78dd4f9cb1fdb/screenshot/serve.png 
-------------------------------------------------------------------------------- /src/app.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use colored::*; 4 | use hyper::{ 5 | header, 6 | service::{make_service_fn, service_fn}, 7 | Body, Request, Response, Server, StatusCode, 8 | }; 9 | 10 | use crate::context::Context; 11 | use crate::error::ObsidianError; 12 | use crate::middleware::Middleware; 13 | use crate::router::{ContextResult, Handler, RouteValueResult, Router}; 14 | 15 | use crate::middleware::logger::Logger; 16 | 17 | #[derive(Clone)] 18 | pub struct DefaultAppState {} 19 | 20 | pub struct App 21 | where 22 | T: Clone + Send + Sync + 'static, 23 | { 24 | router: Router, 25 | app_state: Option, 26 | } 27 | 28 | impl Default for App 29 | where 30 | T: Clone + Send + Sync + 'static, 31 | { 32 | /// create an `Obsidian` app with default middlwares: [`Logger`] 33 | fn default() -> Self { 34 | let mut app = App { 35 | router: Router::new(), 36 | app_state: None, 37 | }; 38 | let logger = Logger::new(); 39 | app.use_service(logger); 40 | app 41 | } 42 | } 43 | 44 | impl App 45 | where 46 | T: Clone + Send + Sync + 'static, 47 | { 48 | pub fn new() -> Self { 49 | App { 50 | router: Router::new(), 51 | app_state: None, 52 | } 53 | } 54 | 55 | pub fn get(&mut self, path: &str, handler: impl Handler) { 56 | self.router.get(path, handler); 57 | } 58 | 59 | pub fn post(&mut self, path: &str, handler: impl Handler) { 60 | self.router.post(path, handler); 61 | } 62 | 63 | pub fn put(&mut self, path: &str, handler: impl Handler) { 64 | self.router.put(path, handler); 65 | } 66 | 67 | pub fn patch(&mut self, path: &str, handler: impl Handler) { 68 | self.router.patch(path, handler); 69 | } 70 | 71 | pub fn delete(&mut self, path: &str, handler: impl Handler) { 72 | self.router.delete(path, handler); 73 | } 74 | 75 | /// Register a nested router for the app 76 | /// 77 | /// Example: 78 | /// ``` 79 | /// 
use obsidian::{App, router::Router, context::Context}; 80 | /// 81 | /// let mut app: App = App::new(); 82 | /// 83 | /// app.scope("admin", |router: &mut Router| { 84 | /// router.get("list", |ctx: Context| async move { 85 | /// ctx.build("Admin list here").ok() 86 | /// }); 87 | /// }); 88 | /// ``` 89 | /// 90 | pub fn scope(&mut self, name: &str, scoped_routes: impl Fn(&mut Router)) { 91 | let mut new_router = Router::new(); 92 | 93 | scoped_routes(&mut new_router); 94 | self.use_router(format!("/{}", name).as_ref(), new_router); 95 | } 96 | 97 | /// Apply middleware in the provided route 98 | pub fn use_service_to(&mut self, path: &str, middleware: impl Middleware) { 99 | self.router.use_service_to(path, middleware); 100 | } 101 | 102 | /// Apply middleware in current relative route 103 | pub fn use_service(&mut self, middleware: impl Middleware) { 104 | self.router.use_service(middleware); 105 | } 106 | 107 | /// Apply route handler in current relative route 108 | pub fn use_router(&mut self, path: &str, router: Router) { 109 | self.router.use_router(path, router); 110 | } 111 | 112 | /// Serve static files by the virtual path as the route and directory path as the server file path 113 | pub fn use_static_to(&mut self, virtual_path: &str, dir_path: &str) { 114 | self.router.use_static_to(virtual_path, dir_path); 115 | } 116 | 117 | /// Serve static files by the directory path as the route and server file path 118 | pub fn use_static(&mut self, dir_path: &str) { 119 | self.router.use_static(dir_path); 120 | } 121 | 122 | /// Set app state. The app state must impl Clone. 123 | /// The app state will be passed into endpoint handler context dynamic data. 
124 | /// 125 | /// # Example 126 | /// ``` 127 | /// use obsidian::App; 128 | /// 129 | /// #[derive(Clone)] 130 | /// struct AppState { 131 | /// db_connection: String, 132 | /// } 133 | /// 134 | /// let mut app: App = App::new(); 135 | /// app.set_app_state(AppState{ 136 | /// db_connection: "localhost:1433".to_string(), 137 | /// }); 138 | /// ``` 139 | pub fn set_app_state(&mut self, app_state: T) { 140 | self.app_state = Some(app_state); 141 | } 142 | 143 | pub async fn listen(self, port: u16) { 144 | let app_server: AppServer = AppServer { 145 | router: self.router, 146 | }; 147 | let app_state = self.app_state; 148 | 149 | let service = make_service_fn(move |_| { 150 | let server_clone = app_server.clone(); 151 | let app_state = app_state.clone(); 152 | 153 | async { 154 | Ok::<_, hyper::Error>(service_fn(move |req| { 155 | let route_value = server_clone.router.search_route(req.uri().path()); 156 | 157 | AppServer::resolve_endpoint(req, route_value, app_state.clone()) 158 | })) 159 | } 160 | }); 161 | 162 | let addr = ([127, 0, 0, 1], port).into(); 163 | let server = Server::bind(&addr).serve(service); 164 | 165 | let logo = r#" 166 | 167 | .oooooo. oooooooooo. .oooooo..o ooooo oooooooooo. ooooo .o. ooooo ooo 168 | d8P' `Y8b `888' `Y8b d8P' `Y8 `888' `888' `Y8b `888' .888. `888b. `8' 169 | 888 888 888 888 Y88bo. 888 888 888 888 .8"888. 8 `88b. 8 170 | 888 888 888oooo888' `"Y8888o. 888 888 888 888 .8' `888. 8 `88b. 8 171 | 888 888 888 `88b `"Y88b 888 888 888 888 .88ooo8888. 8 `88b.8 172 | `88b d88' 888 .88P oo .d8P 888 888 d88' 888 .8' `888. 
8 `888 173 | `Y8bood8P' o888bood8P' 8""88888P' o888o o888bood8P' o888o o88o o8888o o8o `8 174 | 175 | "#; 176 | 177 | println!("{}", logo); 178 | 179 | #[cfg(debug_assertions)] 180 | println!( 181 | " 🚧 {}: dev [{} + {}]", 182 | "Mode".green().bold(), 183 | "unoptimized".red().bold(), 184 | "debuginfo".blue().bold() 185 | ); 186 | 187 | #[cfg(not(debug_assertions))] 188 | println!( 189 | " 🚀 {}: release [{}]", 190 | "Mode".green().bold(), 191 | "optimized".green().bold(), 192 | ); 193 | 194 | println!( 195 | " 🔧 {}: {}", 196 | "Version".green().bold(), 197 | env!("CARGO_PKG_VERSION") 198 | ); 199 | 200 | println!(" 🎉 {}: http://{}\n", "Served at".green().bold(), addr); 201 | 202 | server.await.map_err(|_| println!("Server error")).unwrap(); 203 | } 204 | } 205 | 206 | #[derive(Clone)] 207 | struct AppServer { 208 | router: Router, 209 | } 210 | 211 | impl AppServer { 212 | pub async fn resolve_endpoint( 213 | req: Request, 214 | route_value: Option, 215 | app_state: Option, 216 | ) -> Result, hyper::Error> 217 | where 218 | T: Send + Sync + 'static, 219 | { 220 | match route_value { 221 | Some(route_value) => { 222 | let route = match route_value.get_route(req.method()) { 223 | Some(r) => r, 224 | None => return Ok::<_, hyper::Error>(page_not_found()), 225 | }; 226 | let middlewares = route_value.get_middlewares(); 227 | let params = route_value.get_params(); 228 | let mut context = Context::new(req, params); 229 | let executor = EndpointExecutor::new(&route.handler, middlewares); 230 | 231 | if let Some(state) = app_state { 232 | context.add::(state); 233 | } 234 | 235 | let route_result = executor.next(context).await; 236 | 237 | let route_response = match route_result { 238 | Ok(ctx) => { 239 | let mut res = Response::builder(); 240 | if let Some(response) = ctx.take_response() { 241 | if let Some(headers) = response.headers() { 242 | if let Some(response_headers) = res.headers_mut() { 243 | headers.iter().for_each(move |(key, value)| { 244 | response_headers 
245 | .insert(key, header::HeaderValue::from_static(value)); 246 | }); 247 | } 248 | } 249 | res.status(response.status()).body(response.body()) 250 | } else { 251 | // No response found 252 | res.status(StatusCode::OK).body(Body::from("")) 253 | } 254 | } 255 | Err(err) => { 256 | let body = Body::from(err.to_string()); 257 | Response::builder() 258 | .status(StatusCode::INTERNAL_SERVER_ERROR) 259 | .body(body) 260 | } 261 | }; 262 | 263 | Ok::<_, hyper::Error>(route_response.unwrap_or_else(|_| { 264 | internal_server_error(ObsidianError::GeneralError( 265 | "Error while constructing response body".to_string(), 266 | )) 267 | })) 268 | } 269 | _ => Ok::<_, hyper::Error>(page_not_found()), 270 | } 271 | } 272 | } 273 | 274 | fn page_not_found() -> Response { 275 | let mut server_response = Response::new(Body::from("404 Not Found")); 276 | *server_response.status_mut() = StatusCode::NOT_FOUND; 277 | 278 | server_response 279 | } 280 | 281 | fn internal_server_error(err: impl std::error::Error) -> Response { 282 | let body = Body::from(err.to_string()); 283 | Response::builder() 284 | .status(StatusCode::INTERNAL_SERVER_ERROR) 285 | .body(body) 286 | .unwrap() 287 | } 288 | 289 | pub struct EndpointExecutor<'a> { 290 | pub route_endpoint: &'a Arc, 291 | pub middleware: &'a [Arc], 292 | } 293 | 294 | impl<'a> EndpointExecutor<'a> { 295 | pub fn new( 296 | route_endpoint: &'a Arc, 297 | middleware: &'a [Arc], 298 | ) -> Self { 299 | EndpointExecutor { 300 | route_endpoint, 301 | middleware, 302 | } 303 | } 304 | 305 | pub async fn next(mut self, context: Context) -> ContextResult { 306 | if let Some((current, all_next)) = self.middleware.split_first() { 307 | self.middleware = all_next; 308 | current.handle(context, self).await 309 | } else { 310 | self.route_endpoint.call(context).await 311 | } 312 | } 313 | } 314 | 315 | #[cfg(test)] 316 | mod test { 317 | use super::*; 318 | use crate::context::Context; 319 | use async_std::task; 320 | use hyper::{body, body::Buf, 
/// End-to-end check of `AppServer::resolve_endpoint`: a GET handler on `/`
/// must see the request path and body, and its built response must come back
/// with status 200 and the same body text.
#[test]
fn test_app_server_resolve_endpoint() {
    task::block_on(async {
        let mut router = Router::new();

        router.get("/", |mut ctx: Context| async move {
            let received = body::aggregate(ctx.take_body())
                .await
                .map(|buf| String::from_utf8(buf.chunk().to_vec()))
                .unwrap_or_else(|_| panic!());

            assert_eq!(ctx.uri().path(), "/");
            assert_eq!(received.unwrap(), "test_app_server");
            ctx.build("test_app_server").ok()
        });

        let app_server = AppServer { router };

        let req = Request::builder()
            .uri("/")
            .body(Body::from("test_app_server"))
            .unwrap();

        let route_value = app_server.router.search_route(req.uri().path());
        // No state is passed, so the state type only has to satisfy the bounds.
        let actual_response =
            AppServer::resolve_endpoint::<DefaultAppState>(req, route_value, None)
                .await
                .unwrap();

        let mut expected_response = Response::new(Body::from("test_app_server"));
        *expected_response.status_mut() = StatusCode::OK;

        assert_eq!(actual_response.status(), expected_response.status());

        let actual_text = body::aggregate(actual_response)
            .await
            .map(|buf| String::from_utf8(buf.chunk().to_vec()))
            .unwrap_or_else(|_| panic!());

        let expected_text = body::aggregate(expected_response)
            .await
            .map(|buf| String::from_utf8(buf.chunk().to_vec()))
            .unwrap_or_else(|_| panic!());

        assert_eq!(actual_text.unwrap(), expected_text.unwrap());
    })
}
std::borrow::Cow; 8 | use std::collections::HashMap; 9 | use std::convert::From; 10 | use std::str::FromStr; 11 | 12 | use crate::router::{from_cow_map, ContextResult, Responder, Response}; 13 | use crate::ObsidianError; 14 | use crate::{ 15 | header::{HeaderName, HeaderValue}, 16 | Body, HeaderMap, Method, Request, StatusCode, Uri, 17 | }; 18 | 19 | /// Context contains the data for current http connection context. 20 | /// For example, request information, params, method, and path. 21 | #[derive(Debug)] 22 | pub struct Context { 23 | request: Request, 24 | params_data: HashMap, 25 | response: Option, 26 | } 27 | 28 | impl Context { 29 | pub fn new(request: Request, params_data: HashMap) -> Self { 30 | Self { 31 | request, 32 | params_data, 33 | response: None, 34 | } 35 | } 36 | 37 | /// Access request headers 38 | pub fn headers(&self) -> &HeaderMap { 39 | self.request.headers() 40 | } 41 | 42 | /// Access mutable request header 43 | pub fn headers_mut(&mut self) -> &mut HeaderMap { 44 | self.request.headers_mut() 45 | } 46 | 47 | /// Access request method 48 | pub fn method(&self) -> &Method { 49 | self.request.method() 50 | } 51 | 52 | /// Access request uri 53 | pub fn uri(&self) -> &Uri { 54 | self.request.uri() 55 | } 56 | 57 | /// Access request extensions 58 | pub fn extensions(&self) -> &Extensions { 59 | self.request.extensions() 60 | } 61 | 62 | /// Access mutable request extensions 63 | pub fn extensions_mut(&mut self) -> &mut Extensions { 64 | self.request.extensions_mut() 65 | } 66 | 67 | /// Add dynamic data into request extensions 68 | pub fn add(&mut self, ctx_data: T) { 69 | self.extensions_mut().insert(ctx_data); 70 | } 71 | 72 | /// Get dynamic data from request extensions 73 | pub fn get(&self) -> Option<&T> { 74 | self.extensions().get::() 75 | } 76 | 77 | /// Get mutable dynamic data from request extensions 78 | pub fn get_mut(&mut self) -> Option<&mut T> { 79 | self.extensions_mut().get_mut::() 80 | } 81 | 82 | /// Method to get the params 
value according to key. 83 | /// Panic if key is not found. 84 | /// 85 | /// # Example 86 | /// 87 | /// ``` 88 | /// # use obsidian::StatusCode; 89 | /// # use obsidian::ContextResult; 90 | /// # use obsidian::context::Context; 91 | /// 92 | /// // Assumming ctx contains params for id and mode 93 | /// async fn get_handler(ctx: Context) -> ContextResult { 94 | /// let id: i32 = ctx.param("id")?; 95 | /// let mode: String = ctx.param("mode")?; 96 | /// 97 | /// assert_eq!(id, 1); 98 | /// assert_eq!(mode, "edit".to_string()); 99 | /// 100 | /// ctx.build("").ok() 101 | /// } 102 | /// 103 | /// ``` 104 | /// 105 | pub fn param(&self, key: &str) -> Result { 106 | self.params_data 107 | .get(key) 108 | .ok_or(ObsidianError::NoneError)? 109 | .parse() 110 | .map_err(|_err| ObsidianError::ParamError(format!("Failed to parse param {}", key))) 111 | } 112 | 113 | /// Method to get the string query data from the request url. 114 | /// Untagged is not supported 115 | /// 116 | /// # Example 117 | /// ``` 118 | /// # use serde::*; 119 | /// 120 | /// # use obsidian::context::Context; 121 | /// # use obsidian::{ContextResult, StatusCode}; 122 | /// 123 | /// #[derive(Deserialize, Serialize, Debug)] 124 | /// struct QueryString { 125 | /// id: i32, 126 | /// mode: String, 127 | /// } 128 | /// 129 | /// // Assume ctx contains string query with data {id=1&mode=edit} 130 | /// async fn get_handler(mut ctx: Context) -> ContextResult { 131 | /// let result: QueryString = ctx.query_params()?; 132 | /// 133 | /// assert_eq!(result.id, 1); 134 | /// assert_eq!(result.mode, "edit".to_string()); 135 | /// 136 | /// ctx.build("").ok() 137 | /// } 138 | /// ``` 139 | pub fn query_params(&mut self) -> Result { 140 | let query = match self.uri().query() { 141 | Some(query) => query, 142 | _ => "", 143 | } 144 | .as_bytes(); 145 | 146 | Self::parse_queries(&query) 147 | } 148 | 149 | /// Method to get the forms query data from the request body. 
150 | /// Body is consumed after calling this method. 151 | /// Untagged is not supported 152 | /// 153 | /// # Example 154 | /// ``` 155 | /// # use serde::*; 156 | /// 157 | /// # use obsidian::context::Context; 158 | /// # use obsidian::{ContextResult, StatusCode}; 159 | /// 160 | /// #[derive(Deserialize, Serialize, Debug)] 161 | /// struct FormResult { 162 | /// id: i32, 163 | /// mode: String, 164 | /// } 165 | /// 166 | /// // Assume ctx contains form query with data {id=1&mode=edit} 167 | /// async fn get_handler(mut ctx: Context) -> ContextResult { 168 | /// let result: FormResult = ctx.form().await?; 169 | /// 170 | /// assert_eq!(result.id, 1); 171 | /// assert_eq!(result.mode, "edit".to_string()); 172 | /// 173 | /// ctx.build("").ok() 174 | /// } 175 | /// ``` 176 | pub async fn form(&mut self) -> Result { 177 | let body = self.take_body(); 178 | 179 | let buf = match body::aggregate(body).await { 180 | Ok(buf) => buf, 181 | _ => { 182 | return Err(ObsidianError::NoneError); 183 | } 184 | }; 185 | 186 | Self::parse_queries(buf.chunk()) 187 | } 188 | 189 | /// Form value merge with Params 190 | pub fn form_with_param(&mut self) -> Result { 191 | unimplemented!() 192 | } 193 | 194 | /// Method to get the json data from the request body. Body is consumed after calling this method. 195 | /// The result can be either handled by using static type or dynamic map. 196 | /// Panic if parsing fail. 
197 | /// 198 | /// # Example 199 | /// 200 | /// ### Handle by static type 201 | /// ``` 202 | /// # use serde::*; 203 | /// 204 | /// # use obsidian::context::Context; 205 | /// # use obsidian::{ContextResult, StatusCode}; 206 | /// 207 | /// #[derive(Deserialize, Serialize, Debug)] 208 | /// struct JsonResult { 209 | /// id: i32, 210 | /// mode: String, 211 | /// } 212 | /// 213 | /// // Assume ctx contains json with data {id:1, mode:'edit'} 214 | /// async fn get_handler(mut ctx: Context) -> ContextResult { 215 | /// let result: JsonResult = ctx.json().await?; 216 | /// 217 | /// assert_eq!(result.id, 1); 218 | /// assert_eq!(result.mode, "edit".to_string()); 219 | /// 220 | /// ctx.build("").ok() 221 | /// } 222 | /// ``` 223 | /// 224 | /// ### Handle by dynamic map 225 | /// ``` 226 | /// # use serde_json::Value; 227 | /// 228 | /// # use obsidian::context::Context; 229 | /// # use obsidian::{ContextResult, StatusCode}; 230 | /// 231 | /// // Assume ctx contains json with data {id:1, mode:'edit'} 232 | /// async fn get_handler(mut ctx: Context) -> ContextResult { 233 | /// let result: serde_json::Value = ctx.json().await?; 234 | /// 235 | /// assert_eq!(result["id"], 1); 236 | /// assert_eq!(result["mode"], "edit".to_string()); 237 | /// 238 | /// ctx.build("").ok() 239 | /// } 240 | /// ``` 241 | pub async fn json(&mut self) -> Result { 242 | let body = self.take_body(); 243 | 244 | let buf = match body::aggregate(body).await { 245 | Ok(buf) => buf, 246 | _ => { 247 | return Err(ObsidianError::NoneError); 248 | } 249 | }; 250 | 251 | Ok(serde_json::from_slice(buf.chunk())?) 252 | } 253 | 254 | /// Json value merged with Params 255 | pub fn json_with_param(&mut self) -> Result { 256 | unimplemented!() 257 | } 258 | 259 | /// Consumes body of the request and replace it with empty body. 
260 | pub fn take_body(&mut self) -> Body { 261 | std::mem::replace(self.request.body_mut(), Body::empty()) 262 | } 263 | 264 | /// Take response 265 | pub fn take_response(self) -> Option { 266 | self.response 267 | } 268 | 269 | pub fn response(&self) -> &Option { 270 | &self.response 271 | } 272 | 273 | pub fn response_mut(&mut self) -> &mut Option { 274 | &mut self.response 275 | } 276 | 277 | /// Build any kind of response which implemented Responder trait 278 | pub fn build(self, res: impl Responder) -> ResponseBuilder { 279 | ResponseBuilder::new(self, res.respond_to()) 280 | } 281 | 282 | /// Build data into json format. The data must implement Serialize trait 283 | pub fn build_json(self, body: impl Serialize) -> ResponseBuilder { 284 | ResponseBuilder::new(self, Response::ok().json(body)) 285 | } 286 | 287 | /// Build response from static file. 288 | pub async fn build_file(self, file_path: &str) -> ResponseBuilder { 289 | ResponseBuilder::new(self, Response::ok().file(file_path).await) 290 | } 291 | 292 | fn parse_queries(query: &[u8]) -> Result { 293 | let mut parsed_form_map: HashMap> = HashMap::default(); 294 | let mut cow_form_map = HashMap::, Cow<[String]>>::default(); 295 | 296 | // Parse and merge chunks with same name key 297 | form_urlencoded::parse(query) 298 | .into_owned() 299 | .for_each(|(key, val)| { 300 | if !val.is_empty() { 301 | parsed_form_map 302 | .entry(key) 303 | .or_insert_with(Vec::new) 304 | .push(val); 305 | } 306 | }); 307 | 308 | // Wrap vec with cow pointer 309 | parsed_form_map.iter().for_each(|(key, val)| { 310 | cow_form_map 311 | .entry(std::borrow::Cow::from(key)) 312 | .or_insert_with(|| std::borrow::Cow::from(val)); 313 | }); 314 | 315 | Ok(from_cow_map(&cow_form_map)?) 
316 | } 317 | } 318 | 319 | pub struct ResponseBuilder { 320 | ctx: Context, 321 | response: Response, 322 | } 323 | 324 | impl ResponseBuilder { 325 | pub fn new(ctx: Context, response: Response) -> Self { 326 | ResponseBuilder { ctx, response } 327 | } 328 | 329 | /// set http status code for response 330 | pub fn with_status(mut self, status: StatusCode) -> Self { 331 | self.response = self.response.set_status(status); 332 | self 333 | } 334 | 335 | /// set http header for response 336 | pub fn with_header(mut self, key: HeaderName, value: &'static str) -> Self { 337 | self.response = self.response.set_header(key, value); 338 | self 339 | } 340 | 341 | /// set custom http header for response with `&str` key 342 | pub fn with_header_str(mut self, key: &'static str, value: &'static str) -> Self { 343 | self.response = self.response.set_header_str(key, value); 344 | self 345 | } 346 | 347 | pub fn with_headers(mut self, headers: Vec<(HeaderName, &'static str)>) -> Self { 348 | self.response = self.response.set_headers(headers); 349 | self 350 | } 351 | 352 | pub fn with_headers_str(mut self, headers: Vec<(&'static str, &'static str)>) -> Self { 353 | self.response = self.response.set_headers_str(headers); 354 | self 355 | } 356 | 357 | pub fn ok(mut self) -> ContextResult { 358 | *self.ctx.response_mut() = Some(self.response); 359 | Ok(self.ctx) 360 | } 361 | } 362 | 363 | #[cfg(test)] 364 | mod test { 365 | use super::*; 366 | use async_std::task; 367 | use hyper::{Body, Request}; 368 | use serde::*; 369 | use serde_json::json; 370 | 371 | #[derive(Deserialize, Serialize, Debug, PartialEq)] 372 | struct FormResult { 373 | id: i32, 374 | mode: String, 375 | } 376 | 377 | #[derive(Deserialize, Serialize, Debug, PartialEq)] 378 | struct FormExtraResult { 379 | id: i32, 380 | mode: String, 381 | #[serde(default)] 382 | extra: i32, 383 | } 384 | 385 | #[derive(Deserialize, Serialize, Debug, PartialEq)] 386 | struct JsonResult { 387 | id: i32, 388 | mode: String, 389 | 
/// A param that was never stored must be reported as `NoneError`, while a
/// stored one is still returned.
///
/// The check is an explicit assertion rather than `#[should_panic]`, so the
/// test cannot pass by accident when something else panics first (for
/// example a regression in the `mode` lookup).
#[test]
fn test_params_without_value() {
    let mut params_map = HashMap::default();

    params_map.insert("mode".to_string(), "edit".to_string());

    let request = Request::new(Body::from(""));

    let ctx = Context::new(request, params_map);

    let mode: String = ctx.param("mode").expect("`mode` was inserted above");
    assert_eq!(mode, "edit");

    let missing = ctx.param::<i32>("id");
    assert!(matches!(missing, Err(ObsidianError::NoneError)));
}
let actual_result: FormResult = ctx.form().await?; 462 | let expected_result = FormResult { 463 | id: 1, 464 | mode: "edit".to_string(), 465 | }; 466 | 467 | assert_eq!(actual_result, expected_result); 468 | Ok(()) 469 | }) 470 | } 471 | 472 | #[test] 473 | fn test_form_with_extra_body() -> Result<(), ObsidianError> { 474 | task::block_on(async { 475 | let params = HashMap::default(); 476 | let request = Request::new(Body::from("id=1&mode=edit&extra=true")); 477 | 478 | let mut ctx = Context::new(request, params); 479 | 480 | let actual_result: FormResult = ctx.form().await?; 481 | let expected_result = FormResult { 482 | id: 1, 483 | mode: "edit".to_string(), 484 | }; 485 | 486 | assert_eq!(actual_result, expected_result); 487 | Ok(()) 488 | }) 489 | } 490 | 491 | #[test] 492 | fn test_form_with_extra_field() -> Result<(), ObsidianError> { 493 | task::block_on(async { 494 | let params = HashMap::default(); 495 | let request = Request::new(Body::from("id=1&mode=edit")); 496 | 497 | let mut ctx = Context::new(request, params); 498 | 499 | let actual_result: FormExtraResult = ctx.form().await?; 500 | let expected_result = FormExtraResult { 501 | id: 1, 502 | mode: "edit".to_string(), 503 | extra: i32::default(), 504 | }; 505 | 506 | assert_eq!(actual_result, expected_result); 507 | Ok(()) 508 | }) 509 | } 510 | 511 | #[test] 512 | fn test_json_struct() -> Result<(), ObsidianError> { 513 | task::block_on(async { 514 | let params = HashMap::default(); 515 | let request = Request::new(Body::from("{\"id\":1,\"mode\":\"edit\"}")); 516 | 517 | let mut ctx = Context::new(request, params); 518 | 519 | let actual_result: JsonResult = ctx.json().await?; 520 | let expected_result = JsonResult { 521 | id: 1, 522 | mode: "edit".to_string(), 523 | }; 524 | 525 | assert_eq!(actual_result, expected_result); 526 | Ok(()) 527 | }) 528 | } 529 | 530 | #[test] 531 | fn test_json_value() -> Result<(), ObsidianError> { 532 | task::block_on(async { 533 | let params = HashMap::default(); 
534 | let request = Request::new(Body::from("{\"id\":1,\"mode\":\"edit\"}")); 535 | 536 | let mut ctx = Context::new(request, params); 537 | 538 | let actual_result: serde_json::Value = ctx.json().await?; 539 | 540 | assert_eq!(actual_result["id"], json!(1)); 541 | assert_eq!(actual_result["mode"], json!("edit")); 542 | Ok(()) 543 | }) 544 | } 545 | } 546 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | mod obsidian_error; 2 | 3 | pub use obsidian_error::ObsidianError; 4 | -------------------------------------------------------------------------------- /src/error/obsidian_error.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::fmt; 3 | use std::fmt::Display; 4 | 5 | use serde_json::error::Error as JsonError; 6 | 7 | use crate::router::FormError; 8 | 9 | /// Errors occurs in Obsidian framework 10 | #[derive(Debug)] 11 | pub enum ObsidianError { 12 | ParamError(String), 13 | JsonError(JsonError), 14 | FormError(FormError), 15 | GeneralError(String), 16 | NoneError, 17 | } 18 | 19 | impl Display for ObsidianError { 20 | fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 21 | let error_msg = match *self { 22 | ObsidianError::ParamError(ref msg) => msg.to_string(), 23 | ObsidianError::JsonError(ref err) => err.to_string(), 24 | ObsidianError::FormError(ref err) => err.to_string(), 25 | ObsidianError::GeneralError(ref msg) => msg.to_string(), 26 | ObsidianError::NoneError => "Input should not be None".to_string(), 27 | }; 28 | 29 | formatter.write_str(&error_msg) 30 | } 31 | } 32 | 33 | impl From for ObsidianError { 34 | fn from(error: FormError) -> Self { 35 | ObsidianError::FormError(error) 36 | } 37 | } 38 | 39 | impl From for ObsidianError { 40 | fn from(error: JsonError) -> Self { 41 | ObsidianError::JsonError(error) 42 | } 43 | } 44 | 45 | impl 
Error for ObsidianError { 46 | fn description(&self) -> &str { 47 | "Obsidian Error" 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //#[deny(missing_docs)] 2 | 3 | mod app; 4 | pub mod error; 5 | 6 | pub mod context; 7 | pub mod middleware; 8 | pub mod router; 9 | 10 | pub use app::{App, EndpointExecutor}; 11 | pub use error::ObsidianError; 12 | pub use hyper::{header, Body, HeaderMap, Method, Request, Response, StatusCode, Uri, Version}; 13 | pub use router::ContextResult; 14 | -------------------------------------------------------------------------------- /src/middleware.rs: -------------------------------------------------------------------------------- 1 | pub mod logger; 2 | 3 | use async_trait::async_trait; 4 | 5 | use crate::app::EndpointExecutor; 6 | use crate::context::Context; 7 | use crate::router::ContextResult; 8 | 9 | #[async_trait] 10 | pub trait Middleware: Send + Sync + 'static { 11 | async fn handle<'a>( 12 | &'a self, 13 | context: Context, 14 | ep_executor: EndpointExecutor<'a>, 15 | ) -> ContextResult; 16 | } 17 | -------------------------------------------------------------------------------- /src/middleware/logger.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use std::time::Instant; 3 | 4 | use crate::app::EndpointExecutor; 5 | use crate::context::Context; 6 | use crate::middleware::Middleware; 7 | use crate::router::ContextResult; 8 | 9 | use colored::*; 10 | 11 | #[derive(Default)] 12 | pub struct Logger {} 13 | 14 | impl Logger { 15 | pub fn new() -> Self { 16 | Logger {} 17 | } 18 | } 19 | 20 | #[async_trait] 21 | impl Middleware for Logger { 22 | async fn handle<'a>( 23 | &'a self, 24 | context: Context, 25 | ep_executor: EndpointExecutor<'a>, 26 | ) -> ContextResult { 27 | let start = Instant::now(); 28 | 
println!("[info] {} {}", context.method(), context.uri(),); 29 | 30 | #[cfg(debug_assertions)] 31 | println!("{} {:#?}", "[debug]".cyan(), context); 32 | 33 | match ep_executor.next(context).await { 34 | Ok(context_after) => { 35 | let duration = start.elapsed(); 36 | println!( 37 | "[info] Sent {} in {:?}", 38 | context_after.response().as_ref().unwrap().status(), 39 | duration 40 | ); 41 | Ok(context_after) 42 | } 43 | Err(error) => { 44 | println!("{} {}", "[error]".red(), error); 45 | Err(error) 46 | } 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/router.rs: -------------------------------------------------------------------------------- 1 | mod handler; 2 | mod req_deserializer; 3 | mod resource; 4 | mod responder; 5 | mod response; 6 | mod response_body; 7 | mod route; 8 | mod route_trie; 9 | 10 | use self::route_trie::RouteTrie; 11 | use crate::context::Context; 12 | use crate::middleware::Middleware; 13 | use crate::Method; 14 | pub use hyper::header; 15 | 16 | pub use self::handler::{ContextResult, Handler}; 17 | pub use self::req_deserializer::{from_cow_map, Error as FormError}; 18 | pub use self::resource::Resource; 19 | pub use self::responder::Responder; 20 | pub use self::response::Response; 21 | pub use self::response_body::ResponseBody; 22 | pub use self::route::Route; 23 | 24 | pub(crate) use self::route_trie::RouteValueResult; 25 | 26 | pub struct Router { 27 | routes: RouteTrie, 28 | } 29 | 30 | impl Clone for Router { 31 | fn clone(&self) -> Self { 32 | Router { 33 | routes: self.routes.clone(), 34 | } 35 | } 36 | } 37 | 38 | impl Default for Router { 39 | fn default() -> Self { 40 | Self::new() 41 | } 42 | } 43 | 44 | impl Router { 45 | pub fn new() -> Self { 46 | Router { 47 | routes: RouteTrie::new(), 48 | } 49 | } 50 | 51 | pub fn get(&mut self, path: &str, handler: impl Handler) { 52 | self.insert_route(Method::GET, path, handler); 53 | } 54 | 55 | pub fn post(&mut self, 
path: &str, handler: impl Handler) { 56 | self.insert_route(Method::POST, path, handler); 57 | } 58 | 59 | pub fn put(&mut self, path: &str, handler: impl Handler) { 60 | self.insert_route(Method::PUT, path, handler); 61 | } 62 | 63 | pub fn patch(&mut self, path: &str, handler: impl Handler) { 64 | self.insert_route(Method::PATCH, path, handler); 65 | } 66 | 67 | pub fn delete(&mut self, path: &str, handler: impl Handler) { 68 | self.insert_route(Method::DELETE, path, handler); 69 | } 70 | 71 | /// Apply middleware in the provided route 72 | pub fn use_service_to(&mut self, path: &str, middleware: impl Middleware) { 73 | self.routes.insert_middleware(path, middleware); 74 | } 75 | 76 | /// Apply middleware in current relative route 77 | pub fn use_service(&mut self, middleware: impl Middleware) { 78 | self.routes.insert_default_middleware(middleware); 79 | } 80 | 81 | /// Serve static files by the virtual path as the route and directory path as the server file path 82 | pub fn use_static_to(&mut self, virtual_path: &str, dir_path: &str) { 83 | let mut path = String::from(virtual_path); 84 | path.push_str("/*"); 85 | 86 | self.get( 87 | &path, 88 | Self::static_virtual_file_handler(virtual_path, dir_path), 89 | ); 90 | } 91 | 92 | /// Serve static files by the directory path as the route and server file path 93 | pub fn use_static(&mut self, dir_path: &str) { 94 | let mut path = String::from(dir_path); 95 | path.push_str("/*"); 96 | 97 | self.get(&path, Self::static_dir_file_handler); 98 | } 99 | 100 | /// Apply route handler in current relative route 101 | pub fn use_router(&mut self, path: &str, other: Router) { 102 | RouteTrie::insert_sub_route(&mut self.routes, path, other.routes); 103 | } 104 | 105 | pub fn search_route(&self, path: &str) -> Option { 106 | self.routes.search_route(path) 107 | } 108 | 109 | fn insert_route(&mut self, method: Method, path: &str, handler: impl Handler) { 110 | let route = Route::new(method, handler); 111 | 112 | 
self.routes.insert_route(path, route); 113 | } 114 | 115 | fn static_virtual_file_handler(virtual_path: &str, dir_path: &str) -> impl Handler { 116 | let dir_path = dir_path 117 | .split('/') 118 | .filter(|key| !key.is_empty()) 119 | .map(|x| x.to_string()) 120 | .collect::>(); 121 | 122 | let virtual_path_len = virtual_path 123 | .split('/') 124 | .filter(|key| !key.is_empty()) 125 | .count(); 126 | 127 | move |ctx: Context| { 128 | let mut dir_path = dir_path.clone(); 129 | let mut relative_path = ctx 130 | .uri() 131 | .path() 132 | .split('/') 133 | .filter(|key| !key.is_empty()) 134 | .skip(virtual_path_len) 135 | .map(|x| x.to_string()) 136 | .collect::>(); 137 | 138 | dir_path.append(&mut relative_path); 139 | 140 | Box::pin(async move { 141 | ctx.build(Response::ok().file(&dir_path.join("/")).await) 142 | .ok() 143 | }) 144 | } 145 | } 146 | 147 | async fn static_dir_file_handler(ctx: Context) -> ContextResult { 148 | let relative_path = ctx 149 | .uri() 150 | .path() 151 | .split('/') 152 | .filter(|key| !key.is_empty()) 153 | .map(|x| x.to_string()) 154 | .collect::>(); 155 | 156 | ctx.build(Response::ok().file(&relative_path.join("/")).await) 157 | .ok() 158 | } 159 | } 160 | 161 | #[cfg(test)] 162 | mod tests { 163 | use super::*; 164 | use crate::context::Context; 165 | use crate::middleware::logger::Logger; 166 | 167 | async fn handler(ctx: Context) -> ContextResult { 168 | ctx.build("test").ok() 169 | } 170 | 171 | #[test] 172 | fn router_get_test() { 173 | let mut router = Router::new(); 174 | 175 | router.get("router/test", handler); 176 | 177 | let result = router.search_route("router/test"); 178 | let fail_result = router.search_route("failed"); 179 | 180 | assert!(result.is_some()); 181 | assert!(fail_result.is_none()); 182 | 183 | match result { 184 | Some(route) => { 185 | let middlewares = route.get_middlewares(); 186 | let route_value = route.get_route(&Method::GET).unwrap(); 187 | 188 | assert_eq!(middlewares.len(), 0); 189 | 
assert_eq!(route_value.method, Method::GET); 190 | } 191 | _ => panic!(), 192 | } 193 | } 194 | 195 | #[test] 196 | fn router_post_test() { 197 | let mut router = Router::new(); 198 | 199 | router.post("router/test", handler); 200 | 201 | let result = router.search_route("router/test"); 202 | let fail_result = router.search_route("failed"); 203 | 204 | assert!(result.is_some()); 205 | assert!(fail_result.is_none()); 206 | 207 | match result { 208 | Some(route) => { 209 | let middlewares = route.get_middlewares(); 210 | let route_value = route.get_route(&Method::POST).unwrap(); 211 | 212 | assert_eq!(middlewares.len(), 0); 213 | assert_eq!(route_value.method, Method::POST); 214 | } 215 | _ => panic!(), 216 | } 217 | } 218 | 219 | #[test] 220 | fn router_put_test() { 221 | let mut router = Router::new(); 222 | 223 | router.put("router/test", handler); 224 | 225 | let result = router.search_route("router/test"); 226 | let fail_result = router.search_route("failed"); 227 | 228 | assert!(result.is_some()); 229 | assert!(fail_result.is_none()); 230 | 231 | match result { 232 | Some(route) => { 233 | let middlewares = route.get_middlewares(); 234 | let route_value = route.get_route(&Method::PUT).unwrap(); 235 | 236 | assert_eq!(middlewares.len(), 0); 237 | assert_eq!(route_value.method, Method::PUT); 238 | } 239 | _ => panic!(), 240 | } 241 | } 242 | 243 | #[test] 244 | fn router_delete_test() { 245 | let mut router = Router::new(); 246 | 247 | router.delete("router/test", handler); 248 | 249 | let result = router.search_route("router/test"); 250 | let fail_result = router.search_route("failed"); 251 | 252 | assert!(result.is_some()); 253 | assert!(fail_result.is_none()); 254 | 255 | match result { 256 | Some(route) => { 257 | let middlewares = route.get_middlewares(); 258 | let route_value = route.get_route(&Method::DELETE).unwrap(); 259 | 260 | assert_eq!(middlewares.len(), 0); 261 | assert_eq!(route_value.method, Method::DELETE); 262 | } 263 | _ => panic!(), 264 | } 
265 | } 266 | 267 | #[test] 268 | fn router_root_middleware_test() { 269 | let mut router = Router::new(); 270 | let logger = Logger::new(); 271 | 272 | router.use_service(logger); 273 | 274 | let result = router.search_route("/"); 275 | let fail_result = router.search_route("failed"); 276 | 277 | assert!(result.is_some()); 278 | assert!(fail_result.is_none()); 279 | 280 | match result { 281 | Some(route) => { 282 | let middlewares = route.get_middlewares(); 283 | 284 | assert_eq!(middlewares.len(), 1); 285 | } 286 | _ => panic!(), 287 | } 288 | } 289 | 290 | #[test] 291 | fn router_relative_middleware_test() { 292 | let mut router = Router::new(); 293 | let logger = Logger::new(); 294 | 295 | router.use_service_to("middleware/child", logger); 296 | 297 | let result = router.search_route("/middleware/child"); 298 | let fail_result = router.search_route("/"); 299 | 300 | assert!(result.is_some()); 301 | assert!(fail_result.is_none()); 302 | 303 | match result { 304 | Some(route) => { 305 | let middlewares = route.get_middlewares(); 306 | 307 | assert_eq!(middlewares.len(), 1); 308 | } 309 | _ => panic!(), 310 | } 311 | } 312 | 313 | #[test] 314 | fn router_search_test() { 315 | let mut router = Router::new(); 316 | 317 | router.get("router/test", handler); 318 | router.post("router/test", handler); 319 | router.put("router/test", handler); 320 | router.delete("router/test", handler); 321 | 322 | router.get("route/diff_route", handler); 323 | 324 | let result = router.search_route("router/test"); 325 | let diff_result = router.search_route("route/diff_route"); 326 | let fail_result = router.search_route("failed"); 327 | 328 | assert!(result.is_some()); 329 | assert!(diff_result.is_some()); 330 | assert!(fail_result.is_none()); 331 | 332 | match result { 333 | Some(route) => { 334 | let middlewares = route.get_middlewares(); 335 | let route_value = route.get_route(&Method::GET).unwrap(); 336 | 337 | assert_eq!(middlewares.len(), 0); 338 | 
assert_eq!(route_value.method, Method::GET); 339 | 340 | let route_value = route.get_route(&Method::POST).unwrap(); 341 | 342 | assert_eq!(middlewares.len(), 0); 343 | assert_eq!(route_value.method, Method::POST); 344 | 345 | let route_value = route.get_route(&Method::PUT).unwrap(); 346 | 347 | assert_eq!(middlewares.len(), 0); 348 | assert_eq!(route_value.method, Method::PUT); 349 | 350 | let route_value = route.get_route(&Method::DELETE).unwrap(); 351 | 352 | assert_eq!(middlewares.len(), 0); 353 | assert_eq!(route_value.method, Method::DELETE); 354 | } 355 | _ => panic!(), 356 | } 357 | 358 | match diff_result { 359 | Some(route) => { 360 | let middlewares = route.get_middlewares(); 361 | let route_value = route.get_route(&Method::GET).unwrap(); 362 | 363 | assert_eq!(middlewares.len(), 0); 364 | assert_eq!(route_value.method, Method::GET); 365 | } 366 | _ => panic!(), 367 | } 368 | } 369 | 370 | #[test] 371 | fn router_merge_test() { 372 | let mut main_router = Router::new(); 373 | let mut sub_router = Router::new(); 374 | 375 | main_router.get("router/test", handler); 376 | sub_router.get("router/test", handler); 377 | 378 | let logger = Logger::new(); 379 | 380 | sub_router.use_service(logger); 381 | 382 | main_router.use_router("sub_router", sub_router); 383 | 384 | let result = main_router.search_route("router/test"); 385 | let sub_result = main_router.search_route("sub_router/router/test"); 386 | let fail_result = main_router.search_route("failed"); 387 | 388 | assert!(result.is_some()); 389 | assert!(sub_result.is_some()); 390 | assert!(fail_result.is_none()); 391 | 392 | match result { 393 | Some(route) => { 394 | let middlewares = route.get_middlewares(); 395 | let route_value = route.get_route(&Method::GET).unwrap(); 396 | 397 | assert_eq!(middlewares.len(), 0); 398 | assert_eq!(route_value.method, Method::GET); 399 | } 400 | _ => panic!(), 401 | } 402 | 403 | match sub_result { 404 | Some(route) => { 405 | let middlewares = route.get_middlewares(); 
406 | let route_value = route.get_route(&Method::GET).unwrap(); 407 | 408 | assert_eq!(middlewares.len(), 1); 409 | assert_eq!(route_value.method, Method::GET); 410 | } 411 | _ => panic!(), 412 | } 413 | } 414 | 415 | #[should_panic] 416 | #[test] 417 | fn router_duplicate_path_test() { 418 | let mut router = Router::new(); 419 | 420 | router.get("router/test", handler); 421 | router.get("router/test", handler); 422 | } 423 | 424 | #[should_panic] 425 | #[test] 426 | fn router_ambiguous_path_test() { 427 | let mut router = Router::new(); 428 | 429 | router.get("router/:test", handler); 430 | router.get("router/test", handler); 431 | } 432 | 433 | #[should_panic] 434 | #[test] 435 | fn router_duplicate_merge_test() { 436 | let mut main_router = Router::new(); 437 | let mut sub_router = Router::new(); 438 | 439 | main_router.get("sub_router/test", handler); 440 | sub_router.get("test", handler); 441 | 442 | let logger = Logger::new(); 443 | 444 | sub_router.use_service(logger); 445 | 446 | main_router.use_router("sub_router", sub_router); 447 | } 448 | } 449 | -------------------------------------------------------------------------------- /src/router/handler.rs: -------------------------------------------------------------------------------- 1 | use crate::context::Context; 2 | use crate::error::ObsidianError; 3 | 4 | use async_trait::async_trait; 5 | use std::future::Future; 6 | 7 | pub type ContextResult = Result; 8 | 9 | #[async_trait] 10 | pub trait Handler: Send + Sync + 'static { 11 | async fn call(&self, ctx: Context) -> ContextResult; 12 | } 13 | 14 | #[async_trait] 15 | impl Handler for T 16 | where 17 | T: Fn(Context) -> F + Send + Sync + 'static, 18 | F: Future + Send + 'static, 19 | { 20 | async fn call(&self, ctx: Context) -> ContextResult { 21 | (self)(ctx).await 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/router/req_deserializer.rs: 
-------------------------------------------------------------------------------- 1 | use serde::de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor}; 2 | use serde::forward_to_deserialize_any; 3 | use serde::ser; 4 | use serde::Deserialize; 5 | 6 | use std::borrow::Cow; 7 | use std::collections::HashMap; 8 | use std::fmt; 9 | use std::fmt::Display; 10 | 11 | /// Parse merged forms key, value pair get from form_urlencoded into a user defined struct 12 | /// Key and Value should be in Cow pointer 13 | /// 14 | /// # Example 15 | /// 16 | /// ``` 17 | /// # use obsidian::router::from_cow_map; 18 | /// # use hyper::{Body, Request, body, body::Buf}; 19 | /// # use url::form_urlencoded; 20 | /// # use serde::*; 21 | /// # use std::collections::HashMap; 22 | /// # use std::borrow::Cow; 23 | /// # use async_std::task; 24 | /// 25 | /// #[derive(Deserialize, Debug, PartialEq)] 26 | /// struct Example { 27 | /// field1: Vec, 28 | /// field2: i32, 29 | /// } 30 | /// task::block_on( 31 | /// async { 32 | /// let body = Request::new(Body::from("field1=1&field1=2&field2=12")).into_body(); 33 | /// 34 | /// let buf = match body::aggregate(body).await { 35 | /// Ok(buf) => buf, 36 | /// Err(e) => { 37 | /// println!("{}", e); 38 | /// panic!() 39 | /// } 40 | /// }; 41 | /// 42 | /// let mut parsed_form_map: HashMap> = HashMap::default(); 43 | /// let mut cow_form_map = HashMap::, Cow<[String]>>::default(); 44 | /// 45 | /// // Parse and merge chunks with same name key 46 | /// form_urlencoded::parse(buf.chunk()) 47 | /// .into_owned() 48 | /// .for_each(|(key, val)| { 49 | /// parsed_form_map.entry(key).or_insert(vec![]).push(val); 50 | /// }); 51 | /// 52 | /// // Wrap vec with cow pointer 53 | /// parsed_form_map.iter().for_each(|(key, val)| { 54 | /// cow_form_map 55 | /// .entry(std::borrow::Cow::from(key)) 56 | /// .or_insert(std::borrow::Cow::from(val)); 57 | /// }); 58 | /// 59 | /// let actual_result: Example = 
from_cow_map(&cow_form_map).unwrap(); 60 | /// let expected_result = Example{field1: vec![1,2], field2:12}; 61 | /// 62 | /// assert_eq!(actual_result, expected_result); 63 | /// }) 64 | /// ``` 65 | pub fn from_cow_map<'de, T, S: ::std::hash::BuildHasher>( 66 | s: &'de HashMap, Cow<'de, [String]>, S>, 67 | ) -> Result 68 | where 69 | T: Deserialize<'de>, 70 | { 71 | let mut deserializer = FormDeserializer::from_cow_map(s.iter().peekable()); 72 | let t = T::deserialize(&mut deserializer)?; 73 | Ok(t) 74 | } 75 | 76 | /// Deserializer for merged hashmap forms. 77 | struct FormDeserializer<'de> { 78 | input: std::iter::Peekable< 79 | std::collections::hash_map::Iter< 80 | 'de, 81 | std::borrow::Cow<'de, str>, 82 | std::borrow::Cow<'de, [String]>, 83 | >, 84 | >, 85 | } 86 | 87 | macro_rules! from_string_forms_key_impl { 88 | ($($t:ty => $method:ident)*) => {$( 89 | fn $method(self, visitor: V) -> Result 90 | where V: de::Visitor<'de> 91 | { 92 | match self.input.peek() { 93 | Some(key) => key.0.clone().into_deserializer().$method(visitor), 94 | _ => Err(Error::NoneError), 95 | } 96 | } 97 | )*} 98 | } 99 | 100 | impl<'de> FormDeserializer<'de> { 101 | pub fn from_cow_map( 102 | input: std::iter::Peekable< 103 | std::collections::hash_map::Iter< 104 | 'de, 105 | std::borrow::Cow<'de, str>, 106 | std::borrow::Cow<'de, [String]>, 107 | >, 108 | >, 109 | ) -> Self { 110 | FormDeserializer { input } 111 | } 112 | } 113 | 114 | impl<'de, 'a> de::Deserializer<'de> for &'a mut FormDeserializer<'de> { 115 | type Error = Error; 116 | 117 | fn deserialize_any(self, visitor: V) -> Result 118 | where 119 | V: Visitor<'de>, 120 | { 121 | self.deserialize_map(visitor) 122 | } 123 | 124 | fn deserialize_map(self, visitor: V) -> Result 125 | where 126 | V: Visitor<'de>, 127 | { 128 | visitor.visit_map(FromMap::new(self)) 129 | } 130 | 131 | fn deserialize_string(self, visitor: V) -> Result 132 | where 133 | V: Visitor<'de>, 134 | { 135 | self.deserialize_str(visitor) 136 | } 137 | 
138 | fn deserialize_str(self, visitor: V) -> Result 139 | where 140 | V: Visitor<'de>, 141 | { 142 | match self.input.peek() { 143 | Some(key) => visitor.visit_str(key.0), 144 | _ => Err(Error::NoneError), 145 | } 146 | } 147 | 148 | fn deserialize_identifier(self, visitor: V) -> Result 149 | where 150 | V: Visitor<'de>, 151 | { 152 | self.deserialize_str(visitor) 153 | } 154 | 155 | forward_to_deserialize_any! { 156 | char 157 | option 158 | bytes 159 | byte_buf 160 | unit_struct 161 | newtype_struct 162 | tuple_struct 163 | struct 164 | tuple 165 | enum 166 | ignored_any 167 | unit 168 | seq 169 | } 170 | 171 | from_string_forms_key_impl! { 172 | bool => deserialize_bool 173 | u8 => deserialize_u8 174 | u16 => deserialize_u16 175 | u32 => deserialize_u32 176 | u64 => deserialize_u64 177 | i8 => deserialize_i8 178 | i16 => deserialize_i16 179 | i32 => deserialize_i32 180 | i64 => deserialize_i64 181 | f32 => deserialize_f32 182 | f64 => deserialize_f64 183 | } 184 | } 185 | 186 | macro_rules! from_string_forms_impl { 187 | ($($t:ty => $method:ident)*) => {$( 188 | fn $method(self, visitor: V) -> Result 189 | where V: de::Visitor<'de> 190 | { 191 | match self.input[0].parse::<$t>() { 192 | Ok(result) => result.into_deserializer().$method(visitor), 193 | Err(e) => Err(Error::Message(format!("{}", e))), 194 | } 195 | } 196 | )*} 197 | } 198 | 199 | // Deserializer for value of merged hashmap forms. 
200 | struct FormValueDeserializer<'de> { 201 | input: &'de [std::string::String], 202 | } 203 | 204 | impl<'de> FormValueDeserializer<'de> { 205 | pub fn new(input: &'de [std::string::String]) -> Self { 206 | FormValueDeserializer { input } 207 | } 208 | } 209 | 210 | impl<'de> de::Deserializer<'de> for &mut FormValueDeserializer<'de> { 211 | type Error = Error; 212 | 213 | fn deserialize_any(self, visitor: V) -> Result 214 | where 215 | V: Visitor<'de>, 216 | { 217 | self.deserialize_string(visitor) 218 | } 219 | 220 | fn deserialize_seq(self, visitor: V) -> Result 221 | where 222 | V: Visitor<'de>, 223 | { 224 | visitor.visit_seq(FromSeq::new(self)) 225 | } 226 | 227 | fn deserialize_string(self, visitor: V) -> Result 228 | where 229 | V: Visitor<'de>, 230 | { 231 | visitor.visit_string(self.input[0].clone()) 232 | } 233 | 234 | fn deserialize_str(self, visitor: V) -> Result 235 | where 236 | V: Visitor<'de>, 237 | { 238 | visitor.visit_str(&self.input[0]) 239 | } 240 | 241 | fn deserialize_char(self, visitor: V) -> Result 242 | where 243 | V: Visitor<'de>, 244 | { 245 | match self.input[0].chars().next() { 246 | Some(val) => visitor.visit_char(val), 247 | _ => Err(Error::NoneError), 248 | } 249 | } 250 | 251 | fn deserialize_option(self, visitor: V) -> Result 252 | where 253 | V: Visitor<'de>, 254 | { 255 | if self.input.starts_with(&[String::default()]) { 256 | self.input = &self.input[1..]; 257 | visitor.visit_none() 258 | } else { 259 | visitor.visit_some(self) 260 | } 261 | } 262 | 263 | fn deserialize_unit(self, visitor: V) -> Result 264 | where 265 | V: Visitor<'de>, 266 | { 267 | if self.input.starts_with(&[String::default()]) { 268 | self.input = &self.input[1..]; 269 | } 270 | 271 | visitor.visit_unit() 272 | } 273 | 274 | fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result 275 | where 276 | V: Visitor<'de>, 277 | { 278 | self.deserialize_unit(visitor) 279 | } 280 | 281 | fn deserialize_newtype_struct( 282 | self, 283 | _name: 
&'static str, 284 | visitor: V, 285 | ) -> Result 286 | where 287 | V: Visitor<'de>, 288 | { 289 | visitor.visit_newtype_struct(self) 290 | } 291 | 292 | fn deserialize_tuple(self, _len: usize, visitor: V) -> Result 293 | where 294 | V: Visitor<'de>, 295 | { 296 | self.deserialize_seq(visitor) 297 | } 298 | 299 | fn deserialize_identifier(self, visitor: V) -> Result 300 | where 301 | V: Visitor<'de>, 302 | { 303 | self.deserialize_str(visitor) 304 | } 305 | 306 | forward_to_deserialize_any! { 307 | bytes 308 | byte_buf 309 | ignored_any 310 | map 311 | struct 312 | tuple_struct 313 | enum 314 | } 315 | 316 | from_string_forms_impl! { 317 | bool => deserialize_bool 318 | u8 => deserialize_u8 319 | u16 => deserialize_u16 320 | u32 => deserialize_u32 321 | u64 => deserialize_u64 322 | i8 => deserialize_i8 323 | i16 => deserialize_i16 324 | i32 => deserialize_i32 325 | i64 => deserialize_i64 326 | f32 => deserialize_f32 327 | f64 => deserialize_f64 328 | } 329 | } 330 | 331 | struct FromSeq<'a, 'de: 'a> { 332 | de: &'a mut FormValueDeserializer<'de>, 333 | first: bool, 334 | } 335 | 336 | impl<'a, 'de> FromSeq<'a, 'de> { 337 | fn new(de: &'a mut FormValueDeserializer<'de>) -> Self { 338 | FromSeq { de, first: true } 339 | } 340 | } 341 | 342 | impl<'de, 'a> SeqAccess<'de> for FromSeq<'a, 'de> { 343 | type Error = Error; 344 | 345 | fn next_element_seed(&mut self, seed: T) -> Result, Error> 346 | where 347 | T: DeserializeSeed<'de>, 348 | { 349 | if self.de.input.len() == 1 && !self.first { 350 | return Ok(None); 351 | } 352 | 353 | // Only start moving slices after processing 354 | if !self.first { 355 | self.de.input = &self.de.input[1..]; 356 | } 357 | self.first = false; 358 | 359 | seed.deserialize(&mut *self.de).map(Some) 360 | } 361 | } 362 | 363 | struct FromMap<'a, 'de: 'a> { 364 | de: &'a mut FormDeserializer<'de>, 365 | first: bool, 366 | } 367 | 368 | impl<'a, 'de> FromMap<'a, 'de> { 369 | fn new(de: &'a mut FormDeserializer<'de>) -> Self { 370 | FromMap { 
de, first: true } 371 | } 372 | } 373 | 374 | impl<'de, 'a> MapAccess<'de> for FromMap<'a, 'de> { 375 | type Error = Error; 376 | 377 | fn next_key_seed(&mut self, seed: K) -> Result, Error> 378 | where 379 | K: DeserializeSeed<'de>, 380 | { 381 | if !self.first { 382 | self.de.input.next(); 383 | } 384 | 385 | self.first = false; 386 | 387 | match self.de.input.peek() { 388 | Some(_x) => seed.deserialize(&mut *self.de).map(Some), 389 | _ => Ok(None), 390 | } 391 | } 392 | 393 | fn next_value_seed(&mut self, seed: V) -> Result 394 | where 395 | V: DeserializeSeed<'de>, 396 | { 397 | match self.de.input.peek() { 398 | Some(val) => seed.deserialize(&mut FormValueDeserializer::new(val.1)), 399 | _ => Err(Error::NoneError), 400 | } 401 | } 402 | } 403 | 404 | /// Error for request deserializer 405 | #[derive(Clone, Debug, PartialEq)] 406 | pub enum Error { 407 | Message(String), 408 | NoneError, 409 | } 410 | 411 | impl ser::Error for Error { 412 | fn custom(msg: T) -> Self { 413 | Error::Message(msg.to_string()) 414 | } 415 | } 416 | 417 | impl de::Error for Error { 418 | fn custom(msg: T) -> Self { 419 | Error::Message(msg.to_string()) 420 | } 421 | } 422 | 423 | impl Display for Error { 424 | fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 425 | match *self { 426 | Error::Message(ref msg) => formatter.write_str(msg), 427 | Error::NoneError => formatter.write_str("Input should not be None"), 428 | } 429 | } 430 | } 431 | 432 | impl std::error::Error for Error {} 433 | 434 | #[cfg(test)] 435 | mod tests { 436 | use super::*; 437 | use async_std::task; 438 | use hyper::{body, body::Buf, Body, Request}; 439 | use url::form_urlencoded; 440 | 441 | #[derive(Deserialize, Debug, PartialEq)] 442 | struct VecAndSingleVariableStruct { 443 | field1: Vec, 444 | field2: i32, 445 | } 446 | 447 | #[derive(Deserialize, Debug, PartialEq)] 448 | struct VecStruct { 449 | field1: Vec, 450 | } 451 | 452 | #[derive(Deserialize, Debug, PartialEq)] 453 | struct 
VecWithDefaultStruct { 454 | field1: Vec, 455 | #[serde(default)] 456 | field2: i32, 457 | } 458 | 459 | #[test] 460 | fn test_deserialize_to_struct_with_vec_and_single_variable() { 461 | task::block_on(async { 462 | let body = Request::new(Body::from("field1=abc&field1=xyz&field2=12")).into_body(); 463 | let buf = match body::aggregate(body).await { 464 | Ok(buf) => buf, 465 | Err(e) => { 466 | panic!("Body parsing fail {}", e); 467 | } 468 | }; 469 | let mut parsed_form_map: HashMap> = HashMap::default(); 470 | let mut cow_form_map = HashMap::, Cow<[String]>>::default(); 471 | // Parse and merge chunks with same name key 472 | form_urlencoded::parse(buf.chunk()) 473 | .into_owned() 474 | .for_each(|(key, val)| { 475 | parsed_form_map 476 | .entry(key) 477 | .or_insert_with(Vec::new) 478 | .push(val); 479 | }); 480 | // Wrap vec with cow pointer 481 | parsed_form_map.iter().for_each(|(key, val)| { 482 | cow_form_map 483 | .entry(std::borrow::Cow::from(key)) 484 | .or_insert_with(|| std::borrow::Cow::from(val)); 485 | }); 486 | 487 | let actual_result: VecAndSingleVariableStruct = from_cow_map(&cow_form_map).unwrap(); 488 | let expected_result = VecAndSingleVariableStruct { 489 | field1: vec!["abc".to_string(), "xyz".to_string()], 490 | field2: 12, 491 | }; 492 | assert_eq!(actual_result, expected_result); 493 | }) 494 | } 495 | 496 | #[test] 497 | fn test_deserialize_to_struct_with_vec() { 498 | task::block_on(async { 499 | let body = Request::new(Body::from("field1=1&field1=2")).into_body(); 500 | let buf = match body::aggregate(body).await { 501 | Ok(buf) => buf, 502 | Err(e) => { 503 | panic!("Body parsing fail {}", e); 504 | } 505 | }; 506 | let mut parsed_form_map: HashMap> = HashMap::default(); 507 | let mut cow_form_map = HashMap::, Cow<[String]>>::default(); 508 | // Parse and merge chunks with same name key 509 | form_urlencoded::parse(buf.chunk()) 510 | .into_owned() 511 | .for_each(|(key, val)| { 512 | parsed_form_map 513 | .entry(key) 514 | 
/// A form key with no matching struct field (`field2`) must be ignored
/// rather than make deserialization fail.
#[test]
fn test_deserialize_to_struct_with_extra_form_value() {
    task::block_on(async {
        let body = Request::new(Body::from("field1=1&field1=2&field2=12")).into_body();
        let buf = body::aggregate(body)
            .await
            .unwrap_or_else(|e| panic!("Body parsing fail {}", e));

        // Parse and merge chunks with same name key
        let mut parsed_form_map: HashMap<String, Vec<String>> = HashMap::default();
        for (key, val) in form_urlencoded::parse(buf.chunk()).into_owned() {
            parsed_form_map.entry(key).or_default().push(val);
        }

        // Wrap vec with cow pointer; keys are unique, so a plain collect is enough
        let cow_form_map: HashMap<Cow<str>, Cow<[String]>> = parsed_form_map
            .iter()
            .map(|(key, val)| (Cow::from(key), Cow::from(val)))
            .collect();

        let actual_result: VecStruct = from_cow_map(&cow_form_map).unwrap();
        let expected_result = VecStruct { field1: vec![1, 2] };
        assert_eq!(actual_result, expected_result);
    })
}
HashMap> = HashMap::default(); 575 | let mut cow_form_map = HashMap::, Cow<[String]>>::default(); 576 | // Parse and merge chunks with same name key 577 | form_urlencoded::parse(buf.chunk()) 578 | .into_owned() 579 | .for_each(|(key, val)| { 580 | parsed_form_map 581 | .entry(key) 582 | .or_insert_with(Vec::new) 583 | .push(val); 584 | }); 585 | // Wrap vec with cow pointer 586 | parsed_form_map.iter().for_each(|(key, val)| { 587 | cow_form_map 588 | .entry(std::borrow::Cow::from(key)) 589 | .or_insert_with(|| std::borrow::Cow::from(val)); 590 | }); 591 | 592 | let actual_result: VecWithDefaultStruct = from_cow_map(&cow_form_map).unwrap(); 593 | let expected_result = VecWithDefaultStruct { 594 | field1: vec![1, 2], 595 | field2: i32::default(), 596 | }; 597 | assert_eq!(actual_result, expected_result); 598 | }) 599 | } 600 | 601 | #[test] 602 | fn test_deserialize_to_map_type() { 603 | task::block_on(async { 604 | let body = Request::new(Body::from("field1=1&field1=2&field2=3")).into_body(); 605 | let buf = match body::aggregate(body).await { 606 | Ok(buf) => buf, 607 | Err(e) => { 608 | panic!("Body parsing fail {}", e); 609 | } 610 | }; 611 | let mut parsed_form_map: HashMap> = HashMap::default(); 612 | let mut cow_form_map = HashMap::, Cow<[String]>>::default(); 613 | // Parse and merge chunks with same name key 614 | form_urlencoded::parse(buf.chunk()) 615 | .into_owned() 616 | .for_each(|(key, val)| { 617 | parsed_form_map 618 | .entry(key) 619 | .or_insert_with(Vec::new) 620 | .push(val); 621 | }); 622 | // Wrap vec with cow pointer 623 | parsed_form_map.iter().for_each(|(key, val)| { 624 | cow_form_map 625 | .entry(std::borrow::Cow::from(key)) 626 | .or_insert_with(|| std::borrow::Cow::from(val)); 627 | }); 628 | 629 | let actual_result: HashMap> = from_cow_map(&cow_form_map).unwrap(); 630 | let mut expected_result: HashMap> = HashMap::default(); 631 | 632 | expected_result 633 | .entry("field1".to_string()) 634 | .or_insert_with(Vec::new) 635 | .push(1); 
636 | expected_result 637 | .entry("field1".to_string()) 638 | .or_insert_with(Vec::new) 639 | .push(2); 640 | expected_result 641 | .entry("field2".to_string()) 642 | .or_insert_with(Vec::new) 643 | .push(3); 644 | assert_eq!(actual_result, expected_result); 645 | }) 646 | } 647 | } 648 | -------------------------------------------------------------------------------- /src/router/resource.rs: -------------------------------------------------------------------------------- 1 | use hyper::Method; 2 | use std::collections::HashMap; 3 | 4 | use super::Route; 5 | 6 | /// Resource acts as the intermidiate interface for interaction of routing data structure 7 | /// Resource is binding with the path and handling all of the request method for that path 8 | #[derive(Clone, Debug)] 9 | pub struct Resource { 10 | route_map: HashMap, 11 | } 12 | 13 | impl Default for Resource { 14 | fn default() -> Self { 15 | Resource { 16 | route_map: HashMap::new(), 17 | } 18 | } 19 | } 20 | 21 | impl Resource { 22 | pub fn add_route(&mut self, method: Method, route: Route) -> Option { 23 | self.route_map.insert(method, route) 24 | } 25 | 26 | pub fn get_route(&self, method: &Method) -> Option<&Route> { 27 | self.route_map.get(&method) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/router/responder.rs: -------------------------------------------------------------------------------- 1 | use super::Response; 2 | use super::ResponseBody; 3 | use hyper::{header, StatusCode}; 4 | 5 | pub trait Responder { 6 | fn respond_to(self) -> Response; 7 | fn with_status(self, status: StatusCode) -> Response 8 | where 9 | Self: Responder + ResponseBody + Sized, 10 | { 11 | Response::new(self).set_status(status) 12 | } 13 | 14 | fn with_header(self, key: header::HeaderName, value: &'static str) -> Response 15 | where 16 | Self: Responder + ResponseBody + Sized, 17 | { 18 | Response::new(self).set_header(key, value) 19 | } 20 | 21 | fn 
with_headers(self, headers: Vec<(header::HeaderName, &'static str)>) -> Response 22 | where 23 | Self: Responder + ResponseBody + Sized, 24 | { 25 | Response::new(self).set_headers(headers) 26 | } 27 | 28 | fn with_headers_str(self, headers: Vec<(&'static str, &'static str)>) -> Response 29 | where 30 | Self: Responder + ResponseBody + Sized, 31 | { 32 | Response::new(self).set_headers_str(headers) 33 | } 34 | } 35 | 36 | impl Responder for Response { 37 | fn respond_to(self) -> Response { 38 | self 39 | } 40 | } 41 | 42 | impl Responder for String { 43 | fn respond_to(self) -> Response { 44 | Response::new(self).set_content_type("text/plain; charset=utf-8") 45 | } 46 | } 47 | 48 | impl Responder for &'static str { 49 | fn respond_to(self) -> Response { 50 | self.to_string().respond_to() 51 | } 52 | } 53 | 54 | impl Responder for () { 55 | fn respond_to(self) -> Response { 56 | Response::new(()) 57 | } 58 | } 59 | 60 | impl Responder for (StatusCode, T) 61 | where 62 | T: Responder + ResponseBody, 63 | { 64 | fn respond_to(self) -> Response { 65 | let (status_code, body) = self; 66 | Response::new(body).set_status(status_code) 67 | } 68 | } 69 | 70 | impl Responder for Vec { 71 | fn respond_to(self) -> Response { 72 | match serde_json::to_string(&self) { 73 | Ok(json) => json.with_status(StatusCode::OK).respond_to(), 74 | Err(e) => { 75 | eprintln!("serializing failed: {}", e); 76 | let error = e.to_string(); 77 | error 78 | .with_status(StatusCode::INTERNAL_SERVER_ERROR) 79 | .respond_to() 80 | } 81 | } 82 | } 83 | } 84 | 85 | impl Responder for StatusCode { 86 | fn respond_to(self) -> Response { 87 | ().with_status(self).respond_to() 88 | } 89 | } 90 | 91 | impl Responder for Option { 92 | fn respond_to(self) -> Response { 93 | match self { 94 | Some(resp) => resp.respond_to(), 95 | None => "Not Found" 96 | .to_string() 97 | .with_status(StatusCode::NOT_FOUND) 98 | .respond_to(), 99 | } 100 | } 101 | } 102 | 103 | impl Responder for Option<&'static str> { 104 | 
fn respond_to(self) -> Response { 105 | match self { 106 | Some(resp) => resp.respond_to(), 107 | None => "Not Found" 108 | .to_string() 109 | .with_status(StatusCode::NOT_FOUND) 110 | .respond_to(), 111 | } 112 | } 113 | } 114 | 115 | #[cfg(test)] 116 | mod test { 117 | use super::*; 118 | use hyper::StatusCode; 119 | 120 | #[test] 121 | fn test_str_responder() { 122 | let response = "Hello World".respond_to(); 123 | assert_eq!(response.status(), StatusCode::OK); 124 | // TODO: add testing for body once the Responder is refactored 125 | } 126 | 127 | #[test] 128 | fn test_string_responder() { 129 | let response = "Hello World".to_string().respond_to(); 130 | assert_eq!(response.status(), StatusCode::OK); 131 | // TODO: add testing for body once the Responder is refactored 132 | } 133 | 134 | // #[test] 135 | // fn test_option_responder() { 136 | // let some_result = Some("Hello World").respond_to(); 137 | // if let Ok(response) = some_result { 138 | // assert_eq!(response.status(), StatusCode::OK); 139 | // // TODO: add testing for body once the Responder is refactored 140 | // } 141 | 142 | // let none_result = None::.respond_to(); 143 | // if let Ok(response) = none_result { 144 | // assert_eq!(response.status(), StatusCode::NOT_FOUND); 145 | // // TODO: add testing for body once the Responder is refactored 146 | // } 147 | // } 148 | 149 | // #[test] 150 | // fn test_result_responder() { 151 | // let ok_result = Ok::<&str, &str>("Hello World").respond_to(); 152 | // if let Ok(response) = ok_result { 153 | // assert_eq!(response.status(), StatusCode::OK); 154 | // // TODO: add testing for body once the Responder is refactored 155 | // } 156 | 157 | // let err_result = Err::<&str, &str>("Some error").respond_to(); 158 | // if let Ok(response) = err_result { 159 | // assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); 160 | // // TODO: add testing for body once the Responder is refactored 161 | // } 162 | // } 163 | 164 | #[test] 165 | fn 
test_responder_with_custom_status() { 166 | let response = "Test".with_status(StatusCode::CREATED).respond_to(); 167 | assert_eq!(response.status(), StatusCode::CREATED); 168 | } 169 | 170 | #[test] 171 | fn test_responder_with_custom_header() { 172 | let response = "Test" 173 | .with_header(header::CONTENT_TYPE, "application/json") 174 | .respond_to(); 175 | assert_eq!(response.status(), StatusCode::OK); 176 | assert!(response 177 | .headers() 178 | .as_ref() 179 | .unwrap() 180 | .contains(&(header::CONTENT_TYPE, "application/json"))); 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /src/router/response.rs: -------------------------------------------------------------------------------- 1 | use super::ResponseBody; 2 | 3 | use async_std::fs; 4 | use http::StatusCode; 5 | use hyper::{header, Body}; 6 | use serde::ser::Serialize; 7 | 8 | #[derive(Debug)] 9 | pub struct Response { 10 | body: Body, 11 | status: StatusCode, 12 | headers: Option>, 13 | } 14 | 15 | impl Response { 16 | pub fn new(body: impl ResponseBody) -> Self { 17 | Response { 18 | body: body.into_body(), 19 | status: StatusCode::OK, 20 | headers: None, 21 | } 22 | } 23 | 24 | pub fn status(&self) -> StatusCode { 25 | self.status 26 | } 27 | 28 | pub fn status_mut(&mut self) -> &mut StatusCode { 29 | &mut self.status 30 | } 31 | 32 | pub fn body(self) -> Body { 33 | self.body 34 | } 35 | 36 | pub fn headers(&self) -> &Option> { 37 | &self.headers 38 | } 39 | 40 | pub fn headers_mut(&mut self) -> &mut Option> { 41 | &mut self.headers 42 | } 43 | 44 | pub fn with_status(self, status: StatusCode) -> Self { 45 | self.set_status(status) 46 | } 47 | 48 | pub fn set_status(mut self, status: http::StatusCode) -> Self { 49 | self.status = status; 50 | self 51 | } 52 | 53 | pub fn set_body(mut self, body: impl ResponseBody) -> Self { 54 | self.body = body.into_body(); 55 | self 56 | } 57 | 58 | pub fn set_header(mut self, key: header::HeaderName, value: 
&'static str) -> Self { 59 | match self.headers { 60 | Some(ref mut x) => x.push((key, value)), 61 | None => self.headers = Some(vec![(key, value)]), 62 | }; 63 | self 64 | } 65 | 66 | // Alias set_header method 67 | pub fn with_header(self, key: header::HeaderName, value: &'static str) -> Self { 68 | self.set_header(key, value) 69 | } 70 | 71 | pub fn set_header_str(self, key: &'static str, value: &'static str) -> Self { 72 | self.set_header( 73 | header::HeaderName::from_bytes(key.as_bytes()).unwrap(), 74 | value, 75 | ) 76 | } 77 | 78 | // Alias set_header_str method 79 | pub fn with_header_str(self, key: &'static str, value: &'static str) -> Self { 80 | self.set_header_str(key, value) 81 | } 82 | 83 | pub fn set_content_type(self, content_type: &'static str) -> Self { 84 | self.set_header(header::CONTENT_TYPE, content_type) 85 | } 86 | 87 | pub fn set_headers(mut self, headers: Vec<(header::HeaderName, &'static str)>) -> Self { 88 | match self.headers { 89 | Some(ref mut x) => x.extend_from_slice(&headers), 90 | None => self.headers = Some(headers), 91 | }; 92 | self 93 | } 94 | 95 | // Alias set_headers method 96 | pub fn with_headers(self, headers: Vec<(header::HeaderName, &'static str)>) -> Self { 97 | self.set_headers(headers) 98 | } 99 | 100 | pub fn set_headers_str(mut self, headers: Vec<(&'static str, &'static str)>) -> Self { 101 | let values: Vec<(header::HeaderName, &'static str)> = headers 102 | .iter() 103 | .map(|(k, v)| (header::HeaderName::from_bytes(k.as_bytes()).unwrap(), *v)) 104 | .collect(); 105 | 106 | match self.headers { 107 | Some(ref mut x) => x.extend_from_slice(&values), 108 | None => self.headers = Some(values), 109 | }; 110 | self 111 | } 112 | 113 | // Alias set_headers_str method 114 | pub fn with_headers_str(self, headers: Vec<(&'static str, &'static str)>) -> Self { 115 | self.set_headers_str(headers) 116 | } 117 | 118 | pub fn html(self, body: impl ResponseBody) -> Self { 119 | self.set_content_type("text/html").set_body(body) 
120 | } 121 | 122 | pub fn json(self, body: impl Serialize) -> Self { 123 | match serde_json::to_string(&body) { 124 | Ok(val) => self.set_content_type("application/json").set_body(val), 125 | Err(err) => self 126 | .set_content_type("application/json") 127 | .set_body(err.to_string()) 128 | .set_status(StatusCode::INTERNAL_SERVER_ERROR), 129 | } 130 | } 131 | 132 | pub async fn file(self, file_path: &str) -> Self { 133 | match fs::read_to_string(file_path.to_string()).await { 134 | Ok(content) => self.set_body(content), 135 | Err(err) => { 136 | dbg!(&err); 137 | self.set_body(err.to_string()) 138 | .set_status(StatusCode::NOT_FOUND) 139 | } 140 | } 141 | } 142 | 143 | // Utilities 144 | pub fn ok() -> Self { 145 | Response::new(()).with_status(StatusCode::OK) 146 | } 147 | 148 | pub fn created() -> Self { 149 | Response::new(()).with_status(StatusCode::CREATED) 150 | } 151 | 152 | pub fn internal_server_error() -> Self { 153 | Response::new(()).with_status(StatusCode::INTERNAL_SERVER_ERROR) 154 | } 155 | } 156 | 157 | #[cfg(test)] 158 | mod test { 159 | use super::*; 160 | use hyper::StatusCode; 161 | use serde::*; 162 | 163 | #[test] 164 | fn test_response() { 165 | let response = Response::new("Hello World"); 166 | assert_eq!(response.status(), StatusCode::OK); 167 | // TODO: add testing for body once the Responder is refactored 168 | } 169 | 170 | #[test] 171 | fn test_response_utilities_status() { 172 | assert_eq!(Response::ok().status(), StatusCode::OK); 173 | assert_eq!(Response::created().status(), StatusCode::CREATED); 174 | assert_eq!( 175 | Response::internal_server_error().status(), 176 | StatusCode::INTERNAL_SERVER_ERROR 177 | ); 178 | } 179 | 180 | #[test] 181 | fn test_complete_response() { 182 | #[derive(Serialize, Deserialize, Debug)] 183 | struct Person { 184 | name: String, 185 | age: i8, 186 | } 187 | 188 | let person = Person { 189 | name: String::from("Jun Kai"), 190 | age: 27, 191 | }; 192 | let response = Response::created() 193 | 
.set_header(header::AUTHORIZATION, "token") 194 | .json(person); 195 | 196 | assert_eq!(response.status(), StatusCode::CREATED); 197 | assert!(response 198 | .headers() 199 | .as_ref() 200 | .unwrap() 201 | .contains(&(header::CONTENT_TYPE, "application/json"))); 202 | assert!(response 203 | .headers() 204 | .as_ref() 205 | .unwrap() 206 | .contains(&(header::AUTHORIZATION, "token"))); 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /src/router/response_body.rs: -------------------------------------------------------------------------------- 1 | use hyper::Body; 2 | 3 | pub trait ResponseBody { 4 | fn into_body(self) -> Body; 5 | } 6 | 7 | impl ResponseBody for () { 8 | fn into_body(self) -> Body { 9 | Body::empty() 10 | } 11 | } 12 | 13 | impl ResponseBody for &'static str { 14 | fn into_body(self) -> Body { 15 | Body::from(self) 16 | } 17 | } 18 | 19 | impl ResponseBody for String { 20 | fn into_body(self) -> Body { 21 | Body::from(self) 22 | } 23 | } 24 | 25 | impl ResponseBody for Vec { 26 | fn into_body(self) -> Body { 27 | match serde_json::to_string(&self) { 28 | Ok(json) => Body::from(json), 29 | Err(e) => { 30 | eprintln!("serializing failed: {}", e); 31 | Body::from(e.to_string()) 32 | } 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/router/route.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use super::Handler; 4 | use crate::Method; 5 | 6 | pub struct Route { 7 | pub method: Method, 8 | pub handler: Arc, 9 | } 10 | 11 | impl std::fmt::Debug for Route { 12 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 13 | write!(f, "Route {{ method: {} }}", self.method) 14 | } 15 | } 16 | 17 | impl Clone for Route { 18 | fn clone(&self) -> Route { 19 | Route { 20 | method: self.method.clone(), 21 | handler: self.handler.clone(), 22 | } 23 | } 24 | } 25 | 26 | 
impl Route { 27 | pub fn new(method: Method, handler: impl Handler) -> Self { 28 | Route { 29 | method, 30 | handler: Arc::new(handler), 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/router/route_trie.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::fmt; 3 | use std::sync::Arc; 4 | 5 | use hyper::Method; 6 | 7 | use crate::middleware::Middleware; 8 | use crate::router::Resource; 9 | use crate::router::Route; 10 | use crate::ObsidianError; 11 | 12 | #[derive(Clone, Default)] 13 | pub struct RouteValue { 14 | middlewares: Vec>, 15 | route: Resource, 16 | } 17 | 18 | impl fmt::Debug for RouteValue { 19 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 20 | write!(f, "") 21 | } 22 | } 23 | 24 | impl RouteValue { 25 | pub fn new(middlewares: Vec>, route: Resource) -> Self { 26 | RouteValue { middlewares, route } 27 | } 28 | } 29 | 30 | pub struct RouteValueResult { 31 | route_value: RouteValue, 32 | params: HashMap, 33 | } 34 | 35 | impl RouteValueResult { 36 | pub fn new(route_value: RouteValue, params: HashMap) -> Self { 37 | RouteValueResult { 38 | route_value, 39 | params, 40 | } 41 | } 42 | 43 | pub fn get_route(&self, method: &Method) -> Option<&Route> { 44 | self.route_value.route.get_route(method) 45 | } 46 | 47 | pub fn get_middlewares(&self) -> &Vec> { 48 | &self.route_value.middlewares 49 | } 50 | 51 | pub fn get_params(&self) -> HashMap { 52 | self.params.clone() 53 | } 54 | } 55 | 56 | #[derive(Clone, Debug)] 57 | pub struct RouteTrie { 58 | head: Node, 59 | } 60 | 61 | impl RouteTrie { 62 | pub fn new() -> Self { 63 | RouteTrie { 64 | head: Node::new("/".to_string(), None), 65 | } 66 | } 67 | 68 | /// Insert middleware into root node 69 | pub fn insert_default_middleware(&mut self, middleware: impl Middleware) { 70 | match &mut self.head.value { 71 | Some(val) => { 72 | 
val.middlewares.push(Arc::new(middleware)); 73 | } 74 | None => { 75 | let mut val = RouteValue::default(); 76 | val.middlewares.push(Arc::new(middleware)); 77 | 78 | self.head.value = Some(val); 79 | } 80 | } 81 | } 82 | 83 | /// Insert route values into the trie 84 | /// Panic if ambigous definition is detected 85 | pub fn insert_route(&mut self, path: &str, route: Route) { 86 | // Split path string and drop additional '/' 87 | let mut split_key = path.split('/').filter(|key| !key.is_empty()).peekable(); 88 | 89 | split_key.clone().enumerate().for_each(|(pos, x)| { 90 | if x.contains('*') { 91 | if x.len() != 1 { 92 | panic!("ERROR: Consisting * in route name at: {}", path); 93 | } else if pos != split_key.clone().count() - 1 { 94 | panic!("ERROR: * must be in the last at: {}", path); 95 | } 96 | } 97 | }); 98 | 99 | let mut curr_node = &mut self.head; 100 | 101 | // if the path is "/" 102 | if split_key.peek().is_none() { 103 | self.insert_default_route(route); 104 | return; 105 | } 106 | 107 | while let Some(k) = split_key.next() { 108 | match curr_node.process_insertion(k) { 109 | Ok(next_node) => { 110 | if split_key.peek().is_none() { 111 | match &mut next_node.value { 112 | Some(val) => { 113 | if let Some(duplicated) = 114 | val.route.add_route(route.method.clone(), route) 115 | { 116 | panic!( 117 | "Duplicated route method '{}' at '{}' detected", 118 | duplicated.method, path 119 | ); 120 | } 121 | } 122 | None => { 123 | let mut next_node_val = RouteValue::default(); 124 | if let Some(duplicated) = 125 | next_node_val.route.add_route(route.method.clone(), route) 126 | { 127 | panic!( 128 | "Duplicated route method '{}' at '{}' detected", 129 | duplicated.method, path 130 | ); 131 | } 132 | 133 | next_node.value = Some(next_node_val); 134 | } 135 | } 136 | break; 137 | } 138 | curr_node = next_node; 139 | } 140 | Err(err) => { 141 | panic!("Insert Route: {} at {}", err, path); 142 | } 143 | } 144 | } 145 | } 146 | 147 | /// Insert middleware into 
specific node 148 | pub fn insert_middleware(&mut self, path: &str, middleware: impl Middleware) { 149 | // Split key and drop additional '/' 150 | let split_key = path.split('/'); 151 | let mut split_key = split_key.filter(|key| !key.is_empty()).peekable(); 152 | 153 | split_key.clone().enumerate().for_each(|(pos, key)| { 154 | if key.contains('*') { 155 | if key.len() != 1 { 156 | panic!("ERROR: Consisting * in route name at: {}", path); 157 | } else if pos != split_key.clone().count() - 1 { 158 | panic!("ERROR: * must be in the last at: {}", path); 159 | } 160 | } 161 | }); 162 | 163 | let mut curr_node = &mut self.head; 164 | 165 | while let Some(k) = split_key.next() { 166 | match curr_node.process_insertion(k) { 167 | Ok(next_node) => { 168 | if split_key.peek().is_none() { 169 | match &mut next_node.value { 170 | Some(val) => { 171 | val.middlewares.push(Arc::new(middleware)); 172 | } 173 | None => { 174 | let mut next_node_val = RouteValue::default(); 175 | next_node_val.middlewares.push(Arc::new(middleware)); 176 | 177 | next_node.value = Some(next_node_val); 178 | } 179 | } 180 | break; 181 | } 182 | curr_node = next_node; 183 | } 184 | Err(err) => { 185 | panic!("Middleware: {} at {}", err, path); 186 | } 187 | } 188 | } 189 | } 190 | 191 | /// Search node through the provided key 192 | /// Middleware will be accumulated throughout the search path 193 | pub fn search_route(&self, path: &str) -> Option { 194 | // Split key and drop additional '/' 195 | let split_key = path.split('/'); 196 | let mut split_key = split_key 197 | .filter(|key| !key.is_empty()) 198 | .collect::>(); 199 | 200 | let mut curr_node = &self.head; 201 | let mut params = HashMap::default(); 202 | let mut middlewares = vec![]; 203 | 204 | match &curr_node.value { 205 | Some(val) => { 206 | middlewares.append(&mut val.middlewares.clone()); 207 | } 208 | None => {} 209 | } 210 | 211 | if !split_key.is_empty() { 212 | match curr_node.get_next_node(&mut split_key, &mut params, &mut 
middlewares, false) { 213 | Some(handler_node) => { 214 | curr_node = handler_node; 215 | } 216 | None => { 217 | // Path is not registered 218 | return None; 219 | } 220 | } 221 | } 222 | 223 | match &curr_node.value { 224 | Some(val) => { 225 | let route_val = RouteValue::new(middlewares, val.route.clone()); 226 | 227 | Some(RouteValueResult::new(route_val, params)) 228 | } 229 | None => None, 230 | } 231 | } 232 | 233 | /// Insert src trie into the des as a child trie 234 | /// src will be under the node of des with the key path 235 | /// 236 | /// For example, /src/ -> /des/ with 'example' key path 237 | /// src will be located at /des/example/src/ 238 | pub fn insert_sub_route(des: &mut Self, path: &str, src: Self) { 239 | // Split key and drop additional '/' 240 | let split_key = path.split('/'); 241 | let mut split_key = split_key.filter(|key| !key.is_empty()).peekable(); 242 | 243 | split_key.clone().enumerate().for_each(|(pos, x)| { 244 | if x.contains('*') { 245 | if x.len() != 1 { 246 | panic!("ERROR: Consisting * in route name at: {}", path); 247 | } else if pos != split_key.clone().count() - 1 { 248 | panic!("ERROR: * must be in the last at: {}", path); 249 | } 250 | } 251 | }); 252 | 253 | let mut curr_node = &mut des.head; 254 | 255 | if split_key.peek().is_none() { 256 | des.head = src.head; 257 | return; 258 | } 259 | 260 | while let Some(k) = split_key.next() { 261 | match curr_node.process_insertion(k) { 262 | Ok(next_node) => { 263 | if split_key.peek().is_none() { 264 | if next_node.value.is_some() || !next_node.child_nodes.is_empty() { 265 | panic!("There is conflict between main router and sub router at '{}'. 
Make sure main router does not consist any routing data in '{}'.", path, path); 266 | } 267 | 268 | next_node.value = src.head.value; 269 | next_node.child_nodes = src.head.child_nodes; 270 | break; 271 | } 272 | curr_node = next_node; 273 | } 274 | Err(err) => { 275 | panic!("SubRouter: {} at {}", err, path); 276 | } 277 | } 278 | } 279 | } 280 | 281 | fn insert_default_route(&mut self, route: Route) { 282 | match &mut self.head.value { 283 | Some(val) => { 284 | if let Some(duplicated) = val.route.add_route(route.method.clone(), route) { 285 | panic!( 286 | "Duplicated route method '{}' at '/' detected", 287 | duplicated.method 288 | ); 289 | } 290 | } 291 | None => { 292 | let mut val = RouteValue::default(); 293 | if let Some(duplicated) = val.route.add_route(route.method.clone(), route) { 294 | panic!( 295 | "Duplicated route method '{}' at '/' detected", 296 | duplicated.method 297 | ); 298 | } 299 | 300 | self.head.value = Some(val); 301 | } 302 | } 303 | } 304 | } 305 | 306 | #[derive(Clone, Debug)] 307 | struct Node { 308 | key: String, 309 | value: Option, 310 | child_nodes: Vec, 311 | } 312 | 313 | impl Node { 314 | fn new(key: String, value: Option) -> Self { 315 | Node { 316 | key, 317 | value, 318 | child_nodes: Vec::default(), 319 | } 320 | } 321 | 322 | fn is_param(&self) -> bool { 323 | self.key.chars().next().unwrap_or(' ') == ':' 324 | } 325 | 326 | /// Process the side effects of node insertion 327 | fn process_insertion(&mut self, key: &str) -> Result<&mut Self, ObsidianError> { 328 | let action = self.get_insertion_action(key); 329 | 330 | match action.name { 331 | ActionName::CreateNewNode => { 332 | let new_node = Self::new(key.to_string(), None); 333 | 334 | match key { 335 | k if k == "*" => { 336 | self.child_nodes.push(new_node); 337 | 338 | if let Some(node) = self.child_nodes.last_mut() { 339 | return Ok(node); 340 | }; 341 | } 342 | _ => { 343 | self.child_nodes.insert(0, new_node); 344 | 345 | if let Some(node) = 
self.child_nodes.first_mut() { 346 | return Ok(node); 347 | }; 348 | } 349 | } 350 | } 351 | ActionName::NextNode => { 352 | if let Some(node) = self.child_nodes.get_mut(action.payload.node_index) { 353 | return Ok(node); 354 | }; 355 | } 356 | ActionName::SplitKey => { 357 | if let Some(node) = self.child_nodes.get_mut(action.payload.node_index) { 358 | return node.process_insertion(&key[action.payload.match_count..]); 359 | }; 360 | } 361 | ActionName::SplitNode => { 362 | if let Some(node) = self.child_nodes.get_mut(action.payload.node_index) { 363 | let count = action.payload.match_count; 364 | let child_key = node.key[count..].to_string(); 365 | let new_key = key[count..].to_string(); 366 | node.key = key[..count].to_string(); 367 | 368 | let mut inter_node = Self::new(child_key, None); 369 | 370 | // Move out the previous child and transfer to intermediate node 371 | inter_node.child_nodes = std::mem::take(&mut node.child_nodes); 372 | inter_node.value = std::mem::replace(&mut node.value, None); 373 | 374 | node.child_nodes.insert(0, inter_node); 375 | 376 | // In the case of insert key length less than matched node key length 377 | if new_key.is_empty() { 378 | return Ok(node); 379 | } 380 | 381 | let new_node = Self::new(new_key, None); 382 | 383 | node.child_nodes.insert(0, new_node); 384 | if let Some(result_node) = node.child_nodes.first_mut() { 385 | return Ok(result_node); 386 | } 387 | }; 388 | } 389 | ActionName::Error => { 390 | if let Some(node) = self.child_nodes.get(action.payload.node_index) { 391 | return Err(ObsidianError::GeneralError(format!( 392 | "ERROR: Ambigous definition between {} and {}", 393 | key, node.key 394 | ))); 395 | } 396 | } 397 | } 398 | 399 | unreachable!(); 400 | } 401 | 402 | /// Determine the action required to be performed for the new route path 403 | fn get_insertion_action(&self, key: &str) -> Action { 404 | for (index, node) in self.child_nodes.iter().enumerate() { 405 | let is_param = node.is_param() || 
key.chars().next().unwrap_or(' ') == ':'; 406 | if is_param { 407 | // Only allow one param leaf in one children series 408 | if key == node.key { 409 | return Action::new(ActionName::NextNode, ActionPayload::new(0, index)); 410 | } else { 411 | return Action::new(ActionName::Error, ActionPayload::new(0, index)); 412 | } 413 | } 414 | 415 | let mut temp_key_chars = key.chars(); 416 | let mut count = 0; 417 | 418 | // match characters 419 | for k in node.key.chars() { 420 | let t_k = match temp_key_chars.next() { 421 | Some(key) => key, 422 | None => break, 423 | }; 424 | 425 | if t_k == k { 426 | count += t_k.len_utf8(); 427 | } else { 428 | break; 429 | } 430 | } 431 | 432 | match count { 433 | x if x == key.len() && x == node.key.len() => { 434 | return Action::new(ActionName::NextNode, ActionPayload::new(x, index)) 435 | } 436 | x if x == node.key.len() => { 437 | return Action::new(ActionName::SplitKey, ActionPayload::new(x, index)) 438 | } 439 | x if x != 0 => { 440 | return Action::new(ActionName::SplitNode, ActionPayload::new(x, index)) 441 | } 442 | _ => {} 443 | } 444 | } 445 | 446 | // No child node matched the key, creates new node 447 | Action::new(ActionName::CreateNewNode, ActionPayload::new(0, 0)) 448 | } 449 | // Removes the first entry of `key` (curr_key) and tries each child of `self` in order: a param child accepts any curr_key, a wildcard child ends the search immediately, and any other child must have its whole key as a prefix of curr_key. While entries remain in `key` the search recurses into the matched child. Returns the node the search ends on, or `None` when no child leads to a match. On the success path, param values are stored in `params` under `node.key[1..]`, and the middlewares of each valued node on the path are appended to `middlewares`, deepest node first. `is_break_parent` receives the caller's `break_key` flag; when it is true the param and wildcard checks are skipped. Panics if `key` is empty on entry, because `key.remove(0)` runs unconditionally. NOTE(review): a recursive call that returns `None` has already removed entries from `key`, and nothing visible here restores them before the next sibling is tried - confirm that backtracking between siblings is never needed. 450 | // Helper function to consume the whole key and get the next available node 451 | fn get_next_node( 452 | &self, 453 | key: &mut Vec<&str>, 454 | params: &mut HashMap, 455 | middlewares: &mut Vec>, 456 | is_break_parent: bool, 457 | ) -> Option<&Self> { 458 | let curr_key = key.remove(0); 459 | // curr_key is matched at this level; key now holds only the entries that follow it 460 | for node in self.child_nodes.iter() { 461 | let mut break_key = false; 462 | 463 | if !is_break_parent { 464 | // Check param: no comparison against curr_key is made, so any value is accepted 465 | if node.is_param() { 466 | if key.is_empty() { 467 | match &node.value { 468 | Some(curr_val) => { 469 | params.insert(node.key[1..].to_string(), curr_key.to_string()); 470 | middlewares.append(&mut curr_val.middlewares.clone()); 471 | return Some(node); 472 | } 473 | None => { 474 | continue; 475 | } 476 | } 
477 | } else { 478 | match node.get_next_node(key, params, middlewares, break_key) { 479 | Some(final_val) => { 480 | params.insert(node.key[1..].to_string(), curr_key.to_string()); 481 | 482 | match &node.value { 483 | Some(curr_val) => { 484 | middlewares.append(&mut curr_val.middlewares.clone()); 485 | } 486 | None => {} 487 | } 488 | 489 | return Some(final_val); 490 | } 491 | None => { 492 | continue; 493 | } 494 | } 495 | } 496 | } 497 | 498 | // Check wildcard: the node is returned without looking at the entries left in key, even when it holds no value 499 | if node.key == "*" { 500 | match &node.value { 501 | Some(curr_val) => { 502 | middlewares.append(&mut curr_val.middlewares.clone()); 503 | } 504 | None => {} 505 | } 506 | 507 | return Some(node); 508 | } 509 | } 510 | 511 | let mut temp_key_chars = curr_key.chars(); 512 | let mut count = 0; 513 | 514 | // match characters: count ends up as the byte length of the common prefix of node.key and curr_key 515 | for k in node.key.chars() { 516 | let t_k = match temp_key_chars.next() { 517 | Some(key) => key, 518 | None => break, 519 | }; 520 | 521 | if t_k == k { 522 | count += t_k.len_utf8(); 523 | } else { 524 | break; 525 | } 526 | } 527 | 528 | if count == node.key.len() && count != curr_key.len() { 529 | // break key: node.key is a strict prefix of curr_key, so the unmatched remainder goes back to the front of key 530 | break_key = true; 531 | key.insert(0, &curr_key[count..]); 532 | } 533 | // A non-empty, fully matched node key either ends the search (key empty) or continues in that node's children 534 | if count != 0 && count == node.key.len() { 535 | if key.is_empty() { 536 | match &node.value { 537 | Some(curr_val) => { 538 | middlewares.append(&mut curr_val.middlewares.clone()); 539 | return Some(node); 540 | } 541 | None => { // no value on this node: fall back to a direct wildcard child that has one 542 | for child in node.child_nodes.iter() { 543 | if child.key == "*" { 544 | if let Some(child_val) = &child.value { 545 | middlewares.append(&mut child_val.middlewares.clone()); 546 | return Some(child); 547 | } 548 | } 549 | } 550 | 551 | continue; 552 | } 553 | } 554 | } else if let Some(final_val) = 555 | node.get_next_node(key, params, middlewares, break_key) 556 | { 557 | if let Some(curr_val) = &node.value { 558 | middlewares.append(&mut curr_val.middlewares.clone()); 559 | } 560 | 561 | return Some(final_val); 562 | } 563 | } 564 | // NOTE(review): when the recursive call above returned None, key has already lost the entries that call removed 565 | 
continue; 566 | } 567 | 568 | // Not found: no child led to a match 569 | None 570 | } 571 | } 572 | 573 | /// Action to be performed by the node 574 | enum ActionName { 575 | NextNode, // the key equals the child's key (also returned when the param check finds equal keys) 576 | CreateNewNode, // no child node matched the key 577 | SplitNode, // key and child key share a non-empty prefix shorter than the child key 578 | SplitKey, // the child's whole key is a prefix of the key 579 | Error, // the param check found that the key differs from the child's key 580 | } 581 | 582 | /// Action Payload: 583 | /// match_count: number of bytes matched between the node key and the insert key 584 | /// node_index: node index in the node vector 585 | struct ActionPayload { 586 | match_count: usize, 587 | node_index: usize, 588 | } 589 | 590 | /// Container for an action to be performed in the trie: the action name plus its payload 591 | struct Action { 592 | name: ActionName, 593 | payload: ActionPayload, 594 | } 595 | // Constructor that pairs an action name with its payload 596 | impl Action { 597 | pub fn new(name: ActionName, payload: ActionPayload) -> Self { 598 | Action { name, payload } 599 | } 600 | } 601 | // Constructor that stores the match count and node index exactly as given 602 | impl ActionPayload { 603 | pub fn new(match_count: usize, node_index: usize) -> Self { 604 | ActionPayload { 605 | match_count, 606 | node_index, 607 | } 608 | } 609 | } 610 | 611 | #[cfg(test)] 612 | mod tests { 613 | use super::*; 614 | use crate::context::Context; 615 | use crate::middleware::logger::Logger; 616 | use crate::router::ContextResult; 617 | // Handler registered for the routes in the tests of this module 618 | async fn handler(ctx: Context) -> ContextResult { 619 | ctx.build("test").ok() 620 | } 621 | // The root path must resolve to its GET route and carry exactly the one default middleware 622 | #[test] 623 | fn radix_trie_head_test() { 624 | let mut route_trie = RouteTrie::new(); 625 | let logger = Logger::new(); 626 | 627 | route_trie.insert_default_middleware(logger); 628 | route_trie.insert_route("/", Route::new(Method::GET, handler)); 629 | 630 | let result = route_trie.search_route("/"); 631 | 632 | assert!(result.is_some()); 633 | 634 | match result { 635 | Some(route) => { 636 | let middlewares = route.get_middlewares(); 637 | let route_value = route.get_route(&Method::GET).is_some(); 638 | 639 | assert_eq!(middlewares.len(), 1); 640 | assert!(route_value); 641 | } 642 | _ => panic!(), 643 | } 644 | } 645 | 646 | #[test] 647 | fn radix_trie_normal_test() { 648 | let mut route_trie = RouteTrie::new(); 649 | let logger = Logger::new(); 650 | let 
logger2 = Logger::new(); 651 | 652 | route_trie.insert_default_middleware(logger); 653 | route_trie.insert_route("/normal/test/", Route::new(Method::GET, handler)); 654 | route_trie.insert_route("/ノーマル/テスト/", Route::new(Method::GET, handler)); 655 | route_trie.insert_middleware("/ノーマル/テスト/", logger2); 656 | 657 | let result = route_trie.search_route("/normal/test/"); 658 | 659 | assert!(result.is_some()); 660 | 661 | match result { 662 | Some(route) => { 663 | let middlewares = route.get_middlewares(); 664 | let route_value = route.get_route(&Method::GET).is_some(); 665 | 666 | assert_eq!(middlewares.len(), 1); 667 | assert!(route_value); 668 | } 669 | _ => panic!(), 670 | } 671 | 672 | let result = route_trie.search_route("/ノーマル/テスト/"); 673 | 674 | assert!(result.is_some()); 675 | 676 | match result { 677 | Some(route) => { 678 | let middlewares = route.get_middlewares(); 679 | let route_value = route.get_route(&Method::GET).is_some(); 680 | 681 | assert_eq!(middlewares.len(), 2); 682 | assert!(route_value); 683 | } 684 | _ => panic!(), 685 | } 686 | } 687 | 688 | #[test] 689 | fn radix_trie_not_found_test() { 690 | let mut route_trie = RouteTrie::new(); 691 | let logger = Logger::new(); 692 | 693 | route_trie.insert_default_middleware(logger); 694 | route_trie.insert_route("/normal/test/", Route::new(Method::GET, handler)); 695 | 696 | let result = route_trie.search_route("/fail/test/"); 697 | 698 | assert!(result.is_none()); 699 | } 700 | 701 | #[test] 702 | fn radix_trie_split_node_and_key_test() { 703 | let mut route_trie = RouteTrie::new(); 704 | let logger = Logger::new(); 705 | let logger2 = Logger::new(); 706 | let logger3 = Logger::new(); 707 | 708 | route_trie.insert_default_middleware(logger); 709 | route_trie.insert_route("/normal/test/", Route::new(Method::GET, handler)); 710 | route_trie.insert_route("/noral/test/", Route::new(Method::GET, handler)); 711 | route_trie.insert_route("/ノーマル/テスト/", Route::new(Method::GET, handler)); 712 | 
route_trie.insert_route("/ノーマル/テーブル/", Route::new(Method::GET, handler)); 713 | route_trie.insert_middleware("/noral/test/", logger2); 714 | route_trie.insert_middleware("/ノーマル/テーブル/", logger3); 715 | 716 | let test_cases = vec![ 717 | ("/normal/test/", 1), 718 | ("/noral/test/", 2), 719 | ("/ノーマル/テスト/", 1), 720 | ("/ノーマル/テーブル/", 2), 721 | ]; 722 | 723 | for case in test_cases.iter() { 724 | let normal_result = route_trie.search_route(case.0); 725 | 726 | assert!(normal_result.is_some()); 727 | 728 | match normal_result { 729 | Some(route) => { 730 | let middlewares = route.get_middlewares(); 731 | let route_value = route.get_route(&Method::GET).is_some(); 732 | 733 | assert_eq!(middlewares.len(), case.1); 734 | assert!(route_value); 735 | } 736 | _ => panic!(), 737 | } 738 | } 739 | } 740 | 741 | #[test] 742 | fn radix_trie_wildcard_test() { 743 | let mut route_trie = RouteTrie::new(); 744 | let logger = Logger::new(); 745 | let logger2 = Logger::new(); 746 | let logger3 = Logger::new(); 747 | 748 | route_trie.insert_route("/normal/test/*", Route::new(Method::GET, handler)); 749 | route_trie.insert_middleware("/normal/test/*", logger); 750 | route_trie.insert_middleware("/normal/test/*", logger2); 751 | route_trie.insert_middleware("/normal/test/*", logger3); 752 | 753 | let test_cases = vec![ 754 | "/normal/test/test", 755 | "/normal/test/123", 756 | "/normal/test/こんにちは", 757 | "/normal/test/啊", 758 | ]; 759 | 760 | for case in test_cases.iter() { 761 | let normal_result = route_trie.search_route(case); 762 | 763 | assert!(normal_result.is_some()); 764 | 765 | match normal_result { 766 | Some(route) => { 767 | let middlewares = route.get_middlewares(); 768 | let route_value = route.get_route(&Method::GET).is_some(); 769 | 770 | assert_eq!(middlewares.len(), 3); 771 | assert!(route_value); 772 | } 773 | _ => panic!(), 774 | } 775 | } 776 | } 777 | 778 | #[should_panic] 779 | #[test] 780 | fn radix_trie_wildcard_param_conflict_test() { 781 | let mut route_trie = 
RouteTrie::new(); 782 | 783 | route_trie.insert_route("/normal/test/*", Route::new(Method::GET, handler)); 784 | route_trie.insert_route("/normal/test/:param", Route::new(Method::GET, handler)); 785 | } 786 | 787 | #[should_panic] 788 | #[test] 789 | fn radix_trie_param_wildcard_conflict_test() { 790 | let mut route_trie = RouteTrie::new(); 791 | 792 | route_trie.insert_route("/normal/test/:param", Route::new(Method::GET, handler)); 793 | route_trie.insert_route("/normal/test/*", Route::new(Method::GET, handler)); 794 | } 795 | } 796 | --------------------------------------------------------------------------------