├── .github
├── media
│ └── logo.png
└── workflows
│ └── rust.yml
├── .gitignore
├── .travis.yml
├── CODE_OF_CONDUCT.md
├── Cargo.toml
├── LICENSE
├── README.md
├── examples
├── app_state.rs
├── hello.rs
├── hello_handler.rs
├── json.rs
├── main.rs
└── middleware
│ ├── logger_example.rs
│ └── mod.rs
├── screenshot
└── serve.png
└── src
├── app.rs
├── context.rs
├── error.rs
├── error
└── obsidian_error.rs
├── lib.rs
├── middleware.rs
├── middleware
└── logger.rs
├── router.rs
└── router
├── handler.rs
├── req_deserializer.rs
├── resource.rs
├── responder.rs
├── response.rs
├── response_body.rs
├── route.rs
└── route_trie.rs
/.github/media/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/obsidian-rs/obsidian/bbc25ea05c97cfe621e3000763e78dd4f9cb1fdb/.github/media/logo.png
--------------------------------------------------------------------------------
/.github/workflows/rust.yml:
--------------------------------------------------------------------------------
1 | name: Obsidian Action
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | - develop
8 | - release/*
9 | pull_request:
10 | branches:
11 | - master
12 | - develop
13 |
14 | jobs:
15 | build_stable:
16 | name: Stable
17 | runs-on: ${{ matrix.os }}
18 | strategy:
19 | matrix:
20 | os: [ubuntu-latest, windows-latest, macOS-latest]
21 | rust: [stable]
22 |
23 | steps:
24 | - uses: hecrj/setup-rust-action@v1
25 | with:
26 | rust-version: ${{ matrix.rust }}
27 | components: rustfmt
28 | - uses: actions/checkout@v2
29 | - name: Build
30 | run: cargo build --verbose
31 | - name: Install clippy
32 | run: rustup component add clippy
33 | - name: Check code format
34 | run: cargo fmt --all -- --check
35 | - name: Run clippy
36 | run: cargo clippy --all-targets --all-features -- -D warnings
37 | - name: Run tests
38 | run: cargo test --verbose
39 |
40 | build_nightly:
41 | name: Nightly
42 | runs-on: ${{ matrix.os }}
43 | strategy:
44 | matrix:
45 | os: [ubuntu-latest, windows-latest, macOS-latest]
46 | rust: [nightly]
47 |
48 | steps:
49 | - uses: hecrj/setup-rust-action@v1
50 | with:
51 | rust-version: ${{ matrix.rust }}
52 | - uses: actions/checkout@v2
53 | - name: Build
54 | run: cargo +nightly build --verbose
55 | # - name: Install clippy
56 | # run: rustup component add clippy
57 | # - name: Run clippy
58 | # run: cargo clippy --all-targets --all-features -- -D warnings
59 | - name: Run tests
60 | run: cargo +nightly test --verbose
61 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Generated by Cargo
2 | # will have compiled files and executables
3 | /target/
4 |
5 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
6 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
7 | Cargo.lock
8 |
9 | # These are backup files generated by rustfmt
10 | **/*.rs.bk
11 |
12 | .DS_Store
13 |
14 | # vscode local config
15 | /.vscode/
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: rust
2 | rust:
3 | - stable
4 | - beta
5 | - nightly
6 | branches:
7 | only:
8 | - master
9 | - develop
10 | matrix:
11 | allow_failures:
12 | - rust: nightly
13 | fast_finish: true
14 | cache: cargo
15 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | .
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
129 |
130 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = 'obsidian'
3 | version = '0.3.0-dev'
4 | authors = [
5 | 'Gan Jun Kai ',
6 | 'Wai Pai Lee ',
7 | ]
8 | edition = '2018'
9 | description = 'Ergonomic async http framework for amazing, reliable and efficient web'
10 | readme = "README.md"
11 | homepage = "https://obsidian-rs.github.io"
12 | repository = 'https://github.com/obsidian-rs/obsidian'
13 | license = 'MIT'
14 | keywords = [
15 | 'obsidian',
16 | 'async',
17 | 'http',
18 | 'web',
19 | 'framework'
20 | ]
21 | categories = ["asynchronous", "web-programming::http-server", "network-programming"]
22 |
23 | [[example]]
24 | name = 'example'
25 | path = 'examples/main.rs'
26 |
27 | [[example]]
28 | name = 'hello'
29 | path = 'examples/hello.rs'
30 |
31 | [[example]]
32 | name = 'hello_handler'
33 | path = 'examples/hello_handler.rs'
34 |
35 | [[example]]
36 | name = 'json'
37 | path = 'examples/json.rs'
38 |
39 | [[example]]
40 | name = 'app_state'
41 | path = 'examples/app_state.rs'
42 |
43 | [dependencies]
44 | hyper = { version = "0.14.9", features = [ "full" ] }
45 | http = "0.2.4"
46 | serde = { version = "1.0.126", features = [ "derive" ] }
47 | serde_json = "1.0.64"
48 | url = "2.2.2"
49 | async-std = "1.9.0"
50 | tokio = { version = "1.7.0", features = [ "macros", "rt-multi-thread" ] }
51 | async-trait = "0.1.50"
52 | colored = "2.0.0"
53 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Gan Jun Kai & Wai Pai Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Obsidian
7 |
8 |
9 |
10 | Obsidian is an ergonomic async HTTP framework for Rust, for building reliable and efficient web applications.
11 |
12 |
20 |
21 |
22 |
23 |
24 |
25 | ## Get Started
26 | ```toml
27 | [dependencies]
28 | # add these 2 dependencies in Cargo.toml file
29 | obsidian = "0.2.2"
30 | tokio = "0.2.21"
31 | ```
32 |
33 | ## Hello World
34 |
35 | ```rust
36 | use obsidian::{context::Context, App};
37 |
38 | #[tokio::main]
39 | async fn main() {
40 | let mut app: App = App::new();
41 |
42 | app.get("/", |ctx: Context| async { ctx.build("Hello World").ok() });
43 |
44 | app.listen(3000).await;
45 | }
46 | ```
47 |
48 | ## Hello World (with handler function)
49 |
50 | ```rust
51 | use obsidian::{context::Context, App, ContextResult};
52 |
53 | async fn hello_world(ctx: Context) -> ContextResult {
54 | ctx.build("Hello World").ok()
55 | }
56 |
57 |
58 | #[tokio::main]
59 | async fn main() {
60 | let mut app: App = App::new();
61 |
62 | app.get("/", hello_world);
63 |
64 | app.listen(3000).await;
65 | }
66 | ```
67 |
68 | ## JSON Response
69 |
70 | ```rust
71 | use obsidian::{context::Context, App, ContextResult};
72 | use serde::*;
73 |
74 | async fn get_user(ctx: Context) -> ContextResult {
75 | #[derive(Serialize, Deserialize)]
76 | struct User {
77 | name: String,
78 | };
79 |
80 | let user = User {
81 | name: String::from("Obsidian"),
82 | };
83 | ctx.build_json(user).ok()
84 | }
85 |
86 | #[tokio::main]
87 | async fn main() {
88 | let mut app: App = App::new();
89 |
90 | app.get("/user", get_user);
91 |
92 | app.listen(3000).await;
93 | }
94 |
95 | ```
96 |
97 | ## Example Files
98 |
99 | Examples are located in the `examples/` directory; the full feature tour is `examples/main.rs`.
100 |
101 | ## Run Example
102 |
103 | ```
104 | cargo run --example example
105 | ```
106 |
107 | ## Current State
108 |
109 | NOT READY FOR PRODUCTION YET!
110 |
--------------------------------------------------------------------------------
/examples/app_state.rs:
--------------------------------------------------------------------------------
1 | use obsidian::{context::Context, App, ObsidianError};
2 |
/// Shared application state made available to request handlers.
///
/// It must be `Clone`: the framework clones the state and inserts the copy
/// into each request's context.
#[derive(Clone)]
pub struct AppState {
    /// Connection string for the backing database (example value only).
    pub db_connection_string: String,
}
7 |
8 | #[tokio::main]
9 | async fn main() {
10 | let mut app: App = App::new();
11 |
12 | app.set_app_state(AppState {
13 | db_connection_string: "localhost:1433".to_string(),
14 | });
15 |
16 | app.get("/", |ctx: Context| async {
17 | let app_state = ctx.get::().ok_or(ObsidianError::NoneError)?;
18 | let res = Some(format!(
19 | "connection string: {}",
20 | &app_state.db_connection_string
21 | ));
22 |
23 | ctx.build(res).ok()
24 | });
25 |
26 | app.listen(3000).await;
27 | }
28 |
--------------------------------------------------------------------------------
/examples/hello.rs:
--------------------------------------------------------------------------------
1 | use obsidian::{context::Context, App};
2 |
3 | #[tokio::main]
4 | async fn main() {
5 | let mut app: App = App::new();
6 |
7 | app.get("/", |ctx: Context| async { ctx.build("Hello World").ok() });
8 |
9 | app.listen(3000).await;
10 | }
11 |
--------------------------------------------------------------------------------
/examples/hello_handler.rs:
--------------------------------------------------------------------------------
1 | use obsidian::{context::Context, App, ContextResult};
2 |
3 | async fn hello_world(ctx: Context) -> ContextResult {
4 | ctx.build("Hello World").ok()
5 | }
6 |
7 | #[tokio::main]
8 | async fn main() {
9 | let mut app: App = App::new();
10 |
11 | app.get("/", hello_world);
12 |
13 | app.listen(3000).await;
14 | }
15 |
--------------------------------------------------------------------------------
/examples/json.rs:
--------------------------------------------------------------------------------
1 | use obsidian::{context::Context, App, ContextResult};
2 | use serde::*;
3 |
4 | async fn get_user(ctx: Context) -> ContextResult {
5 | #[derive(Serialize, Deserialize)]
6 | struct User {
7 | name: String,
8 | }
9 |
10 | let user = User {
11 | name: String::from("Obsidian"),
12 | };
13 | ctx.build_json(user).ok()
14 | }
15 |
16 | #[tokio::main]
17 | async fn main() {
18 | let mut app: App = App::new();
19 |
20 | app.get("/user", get_user);
21 |
22 | app.listen(3000).await;
23 | }
24 |
--------------------------------------------------------------------------------
/examples/main.rs:
--------------------------------------------------------------------------------
1 | mod middleware;
2 |
3 | use middleware::logger_example::*;
4 | use serde::*;
5 | use std::{fmt, fmt::Display};
6 |
7 | use obsidian::{
8 | context::Context,
9 | router::{header, Responder, Response, Router},
10 | App, ObsidianError, StatusCode,
11 | };
12 |
13 | // Testing example
14 | #[derive(Serialize, Deserialize, Debug)]
15 | struct Point {
16 | x: i32,
17 | y: i32,
18 | }
19 |
20 | #[derive(Serialize, Deserialize, Debug)]
21 | struct JsonTest {
22 | title: String,
23 | content: String,
24 | }
25 |
26 | #[derive(Serialize, Deserialize, Debug)]
27 | struct ParamTest {
28 | test: Vec,
29 | test2: String,
30 | }
31 |
32 | #[derive(Serialize, Deserialize, Debug)]
33 | struct User {
34 | name: String,
35 | age: i8,
36 | }
37 |
38 | impl Display for JsonTest {
39 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 | write!(f, "{{title: {}, content: {}}}", self.title, self.content)
41 | }
42 | }
43 |
44 | // async fn responder_json(mut ctx: Context) -> impl Responder {
45 | // let person: Person = ctx.json().await?;
46 |
47 | // person.age += 1;
48 |
49 | // Ok(response::json(person))
50 | // }
51 |
52 | // async fn responder_obsidian_error(mut ctx: Context) -> impl Responder {
53 | // let json: JsonTest = ctx.json().await?;
54 | // println!("{}", json);
55 | // Ok(response::json(json, StatusCode::OK))
56 | // }
57 |
58 | // fn responder_with_header(ctx: Context: Context) -> impl Responder {
59 | // let headers = vec![
60 | // ("X-Custom-Header-4", "custom-value-4"),
61 | // ("X-Custom-Header-5", "custom-value-5"),
62 | // ];
63 |
64 | // "here"
65 | // .header("Content-Type", "application/json")
66 | // .header("X-Custom-Header", "custom-value")
67 | // .header("X-Custom-Header-2", "custom-value-2")
68 | // .header("X-Custom-Header-3", "custom-value-3")
69 | // .set_headers(headers)
70 | // .status(StatusCode::CREATED)
71 | // }
72 |
73 | #[tokio::main]
74 | async fn main() {
75 | let mut app: App = App::new();
76 |
77 | app.get("/", |ctx: Context| async {
78 | ctx.build(Response::ok().html(" Hello Obsidian ")).ok()
79 | });
80 |
81 | app.get("/json", |ctx: Context| async {
82 | let point = Point { x: 1, y: 2 };
83 |
84 | ctx.build_json(point)
85 | .with_status(StatusCode::OK)
86 | .with_header(header::AUTHORIZATION, "token")
87 | .with_header_str("X-Custom-Header", "Custom header value")
88 | .ok()
89 | });
90 |
91 | app.get("/user", |mut ctx: Context| async {
92 | #[derive(Serialize, Deserialize, Debug)]
93 | struct QueryString {
94 | id: String,
95 | status: String,
96 | }
97 |
98 | let params = match ctx.query_params::() {
99 | Ok(params) => params,
100 | Err(error) => {
101 | println!("error: {}", error);
102 | QueryString {
103 | id: String::from(""),
104 | status: String::from(""),
105 | }
106 | }
107 | };
108 |
109 | println!("params: {:?}", params);
110 |
111 | ctx.build("").ok()
112 | });
113 |
114 | app.patch("/patch-here", |ctx: Context| async {
115 | ctx.build("Here is patch request").ok()
116 | });
117 |
118 | app.get("/json-with-headers", |ctx: Context| async {
119 | let point = Point { x: 1, y: 2 };
120 |
121 | let custom_headers = vec![
122 | ("X-Custom-Header-1", "Custom header 1"),
123 | ("X-Custom-Header-2", "Custom header 2"),
124 | ("X-Custom-Header-3", "Custom header 3"),
125 | ];
126 |
127 | let standard_headers = vec![
128 | (header::AUTHORIZATION, "token"),
129 | (header::ACCEPT_CHARSET, "utf-8"),
130 | ];
131 |
132 | ctx.build(
133 | Response::created()
134 | .with_headers(standard_headers)
135 | .with_headers_str(custom_headers)
136 | .json(point),
137 | )
138 | .ok()
139 | });
140 |
141 | app.get("/string-with-headers", |ctx: Context| async {
142 | let custom_headers = vec![
143 | ("X-Custom-Header-1", "Custom header 1"),
144 | ("X-Custom-Header-2", "Custom header 2"),
145 | ("X-Custom-Header-3", "Custom header 3"),
146 | ];
147 |
148 | let standard_headers = vec![
149 | (header::AUTHORIZATION, "token"),
150 | (header::ACCEPT_CHARSET, "utf-8"),
151 | ];
152 |
153 | ctx.build("Hello World")
154 | .with_headers(standard_headers)
155 | .with_headers_str(custom_headers)
156 | .ok()
157 | });
158 |
159 | app.get("/empty-body", |ctx: Context| async {
160 | ctx.build(StatusCode::OK).ok()
161 | });
162 |
163 | app.get("/vec", |ctx: Context| async {
164 | ctx.build(vec![1, 2, 3])
165 | .with_status(StatusCode::CREATED)
166 | .ok()
167 | });
168 |
169 | app.get("/String", |ctx: Context| async {
170 | ctx.build("This is a String ".to_string()).ok()
171 | });
172 |
173 | app.get("/test/radix", |ctx: Context| async {
174 | ctx.build("Test radix ".to_string()).ok()
175 | });
176 |
177 | app.get("/team/radix", |ctx: Context| async {
178 | ctx.build("Team radix".to_string()).ok()
179 | });
180 |
181 | app.get("/test/radix2", |ctx: Context| async {
182 | ctx.build("Test radix2 ".to_string()).ok()
183 | });
184 |
185 | app.get("/jsontest", |ctx: Context| async {
186 | ctx.build_file("./testjson.html").await.ok()
187 | });
188 |
189 | app.get("/jsan", |ctx: Context| async {
190 | ctx.build("jsan ".to_string()).ok()
191 | });
192 |
193 | app.get("/test/wildcard/*", |ctx: Context| async move {
194 | let res = format!(
195 | "{} {}",
196 | "Test wildcard ".to_string(),
197 | ctx.uri().path()
198 | );
199 |
200 | ctx.build(res).ok()
201 | });
202 |
203 | app.get("router/test", |ctx: Context| async move {
204 | let result = ctx
205 | .extensions()
206 | .get::()
207 | .ok_or(ObsidianError::NoneError)?;
208 |
209 | dbg!(&result.0);
210 |
211 | let res = Some(format!(
212 | "{} {}",
213 | "router test get ".to_string(),
214 | ctx.uri().path()
215 | ));
216 |
217 | ctx.build(res).ok()
218 | });
219 | app.post("router/test", |ctx: Context| async move {
220 | let res = format!(
221 | "{} {}",
222 | "router test post ".to_string(),
223 | ctx.uri().path()
224 | );
225 |
226 | ctx.build(res).ok()
227 | });
228 | app.put("router/test", |ctx: Context| async move {
229 | let res = format!(
230 | "{} {}",
231 | "router test put ".to_string(),
232 | ctx.uri().path()
233 | );
234 |
235 | ctx.build(res).ok()
236 | });
237 | app.delete("router/test", |ctx: Context| async move {
238 | let res = format!(
239 | "{} {}",
240 | "router test delete ".to_string(),
241 | ctx.uri().path()
242 | );
243 |
244 | ctx.build(res).ok()
245 | });
246 |
247 | app.get("route/diff_route", |ctx: Context| async move {
248 | let res = format!(
249 | "{} {}",
250 | "route diff get ".to_string(),
251 | ctx.uri().path()
252 | );
253 |
254 | ctx.build(res).ok()
255 | });
256 |
257 | app.scope("admin", |router: &mut Router| {
258 | router.get("test", |ctx: Context| async move {
259 | ctx.build("Hello admin test").ok()
260 | });
261 |
262 | router.get("test2", |ctx: Context| async move {
263 | ctx.build("Hello admin test 2").ok()
264 | });
265 | });
266 |
267 | app.scope("form", |router: &mut Router| {
268 | router.get("/formtest", |ctx: Context| async move {
269 | ctx.build_file("/.test.html").await.ok()
270 | });
271 | });
272 |
273 | // form_router.post("/formtest", |mut ctx: Context| async move{
274 | // let param_test: ParamTest = ctx.form().await?;
275 |
276 | // dbg!(¶m_test);
277 |
278 | // Ok(response::json(param_test, StatusCode::OK))
279 | // });
280 |
281 | // param_router.get("/paramtest/:id", |ctx: Context| async move {
282 | // let param_test: i32 = ctx.param("id")?;
283 |
284 | // dbg!(¶m_test);
285 |
286 | // Ok(response::json(param_test, StatusCode::OK))
287 | // });
288 | //
289 | // param_router.get("/paramtest/:id/test", |ctx: Context| async move {
290 | // let mut param_test: i32 = ctx.param("id").unwrap();
291 | // param_test = param_test * 10;
292 |
293 | // dbg!(¶m_test);
294 |
295 | // Ok(response::json(param_test, StatusCode::OK))
296 | // });
297 |
298 | let logger_example = middleware::logger_example::LoggerExample::new();
299 | app.use_service(logger_example);
300 |
301 | app.scope("params", |router: &mut Router| {
302 | router.get("/test-next-wild/*", |ctx: Context| async {
303 | ctx.build("test next wild ".to_string()).ok()
304 | });
305 |
306 | router.get("/*", |ctx: Context| async {
307 | ctx.build(
308 | "404 Not Found "
309 | .to_string()
310 | .with_status(StatusCode::NOT_FOUND),
311 | )
312 | .ok()
313 | });
314 | });
315 |
316 | app.use_static_to("/files/", "/assets/");
317 |
318 | app.listen(3000).await;
319 | }
320 |
--------------------------------------------------------------------------------
/examples/middleware/logger_example.rs:
--------------------------------------------------------------------------------
1 | use async_trait::async_trait;
2 |
3 | #[cfg(debug_assertions)]
4 | use colored::*;
5 |
6 | use obsidian::{context::Context, middleware::Middleware, ContextResult, EndpointExecutor};
7 |
/// Example middleware. It prints a request line and attaches a
/// [`LoggerExampleData`] value to the request context.
#[derive(Default)]
pub struct LoggerExample {}

/// Value stored in the request context by [`LoggerExample`]; the string is a
/// free-form message.
pub struct LoggerExampleData(pub String);

impl LoggerExample {
    /// Creates the middleware; equivalent to `LoggerExample::default()`.
    // NOTE(review): the allow is kept from the original, presumably to silence
    // dead-code warnings when a consumer of this module never calls `new`.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }
}
19 |
20 | #[async_trait]
21 | impl Middleware for LoggerExample {
22 | async fn handle<'a>(
23 | &'a self,
24 | mut context: Context,
25 | ep_executor: EndpointExecutor<'a>,
26 | ) -> ContextResult {
27 | #[cfg(debug_assertions)]
28 | println!("{}", "[debug] Inside middleware".blue());
29 |
30 | println!(
31 | "{} {}{}",
32 | context.method(),
33 | context.headers().get("host").unwrap().to_str().unwrap(),
34 | context.uri(),
35 | );
36 |
37 | context.add(LoggerExampleData("This is logger data".to_string()));
38 |
39 | ep_executor.next(context).await
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/examples/middleware/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod logger_example;
2 |
--------------------------------------------------------------------------------
/screenshot/serve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/obsidian-rs/obsidian/bbc25ea05c97cfe621e3000763e78dd4f9cb1fdb/screenshot/serve.png
--------------------------------------------------------------------------------
/src/app.rs:
--------------------------------------------------------------------------------
1 | use std::sync::Arc;
2 |
3 | use colored::*;
4 | use hyper::{
5 | header,
6 | service::{make_service_fn, service_fn},
7 | Body, Request, Response, Server, StatusCode,
8 | };
9 |
10 | use crate::context::Context;
11 | use crate::error::ObsidianError;
12 | use crate::middleware::Middleware;
13 | use crate::router::{ContextResult, Handler, RouteValueResult, Router};
14 |
15 | use crate::middleware::logger::Logger;
16 |
/// Empty state type for apps that do not need shared application state.
#[derive(Clone, Debug, Default)]
pub struct DefaultAppState {}
19 |
20 | pub struct App
21 | where
22 | T: Clone + Send + Sync + 'static,
23 | {
24 | router: Router,
25 | app_state: Option,
26 | }
27 |
28 | impl Default for App
29 | where
30 | T: Clone + Send + Sync + 'static,
31 | {
32 | /// create an `Obsidian` app with default middlwares: [`Logger`]
33 | fn default() -> Self {
34 | let mut app = App {
35 | router: Router::new(),
36 | app_state: None,
37 | };
38 | let logger = Logger::new();
39 | app.use_service(logger);
40 | app
41 | }
42 | }
43 |
44 | impl App
45 | where
46 | T: Clone + Send + Sync + 'static,
47 | {
48 | pub fn new() -> Self {
49 | App {
50 | router: Router::new(),
51 | app_state: None,
52 | }
53 | }
54 |
55 | pub fn get(&mut self, path: &str, handler: impl Handler) {
56 | self.router.get(path, handler);
57 | }
58 |
59 | pub fn post(&mut self, path: &str, handler: impl Handler) {
60 | self.router.post(path, handler);
61 | }
62 |
63 | pub fn put(&mut self, path: &str, handler: impl Handler) {
64 | self.router.put(path, handler);
65 | }
66 |
67 | pub fn patch(&mut self, path: &str, handler: impl Handler) {
68 | self.router.patch(path, handler);
69 | }
70 |
71 | pub fn delete(&mut self, path: &str, handler: impl Handler) {
72 | self.router.delete(path, handler);
73 | }
74 |
75 | /// Register a nested router for the app
76 | ///
77 | /// Example:
78 | /// ```
79 | /// use obsidian::{App, router::Router, context::Context};
80 | ///
81 | /// let mut app: App = App::new();
82 | ///
83 | /// app.scope("admin", |router: &mut Router| {
84 | /// router.get("list", |ctx: Context| async move {
85 | /// ctx.build("Admin list here").ok()
86 | /// });
87 | /// });
88 | /// ```
89 | ///
90 | pub fn scope(&mut self, name: &str, scoped_routes: impl Fn(&mut Router)) {
91 | let mut new_router = Router::new();
92 |
93 | scoped_routes(&mut new_router);
94 | self.use_router(format!("/{}", name).as_ref(), new_router);
95 | }
96 |
97 | /// Apply middleware in the provided route
98 | pub fn use_service_to(&mut self, path: &str, middleware: impl Middleware) {
99 | self.router.use_service_to(path, middleware);
100 | }
101 |
102 | /// Apply middleware in current relative route
103 | pub fn use_service(&mut self, middleware: impl Middleware) {
104 | self.router.use_service(middleware);
105 | }
106 |
107 | /// Apply route handler in current relative route
108 | pub fn use_router(&mut self, path: &str, router: Router) {
109 | self.router.use_router(path, router);
110 | }
111 |
112 | /// Serve static files by the virtual path as the route and directory path as the server file path
113 | pub fn use_static_to(&mut self, virtual_path: &str, dir_path: &str) {
114 | self.router.use_static_to(virtual_path, dir_path);
115 | }
116 |
117 | /// Serve static files by the directory path as the route and server file path
118 | pub fn use_static(&mut self, dir_path: &str) {
119 | self.router.use_static(dir_path);
120 | }
121 |
122 | /// Set app state. The app state must impl Clone.
123 | /// The app state will be passed into endpoint handler context dynamic data.
124 | ///
125 | /// # Example
126 | /// ```
127 | /// use obsidian::App;
128 | ///
129 | /// #[derive(Clone)]
130 | /// struct AppState {
131 | /// db_connection: String,
132 | /// }
133 | ///
134 | /// let mut app: App = App::new();
135 | /// app.set_app_state(AppState{
136 | /// db_connection: "localhost:1433".to_string(),
137 | /// });
138 | /// ```
139 | pub fn set_app_state(&mut self, app_state: T) {
140 | self.app_state = Some(app_state);
141 | }
142 |
143 | pub async fn listen(self, port: u16) {
144 | let app_server: AppServer = AppServer {
145 | router: self.router,
146 | };
147 | let app_state = self.app_state;
148 |
149 | let service = make_service_fn(move |_| {
150 | let server_clone = app_server.clone();
151 | let app_state = app_state.clone();
152 |
153 | async {
154 | Ok::<_, hyper::Error>(service_fn(move |req| {
155 | let route_value = server_clone.router.search_route(req.uri().path());
156 |
157 | AppServer::resolve_endpoint(req, route_value, app_state.clone())
158 | }))
159 | }
160 | });
161 |
162 | let addr = ([127, 0, 0, 1], port).into();
163 | let server = Server::bind(&addr).serve(service);
164 |
165 | let logo = r#"
166 |
167 | .oooooo. oooooooooo. .oooooo..o ooooo oooooooooo. ooooo .o. ooooo ooo
168 | d8P' `Y8b `888' `Y8b d8P' `Y8 `888' `888' `Y8b `888' .888. `888b. `8'
169 | 888 888 888 888 Y88bo. 888 888 888 888 .8"888. 8 `88b. 8
170 | 888 888 888oooo888' `"Y8888o. 888 888 888 888 .8' `888. 8 `88b. 8
171 | 888 888 888 `88b `"Y88b 888 888 888 888 .88ooo8888. 8 `88b.8
172 | `88b d88' 888 .88P oo .d8P 888 888 d88' 888 .8' `888. 8 `888
173 | `Y8bood8P' o888bood8P' 8""88888P' o888o o888bood8P' o888o o88o o8888o o8o `8
174 |
175 | "#;
176 |
177 | println!("{}", logo);
178 |
179 | #[cfg(debug_assertions)]
180 | println!(
181 | " 🚧 {}: dev [{} + {}]",
182 | "Mode".green().bold(),
183 | "unoptimized".red().bold(),
184 | "debuginfo".blue().bold()
185 | );
186 |
187 | #[cfg(not(debug_assertions))]
188 | println!(
189 | " 🚀 {}: release [{}]",
190 | "Mode".green().bold(),
191 | "optimized".green().bold(),
192 | );
193 |
194 | println!(
195 | " 🔧 {}: {}",
196 | "Version".green().bold(),
197 | env!("CARGO_PKG_VERSION")
198 | );
199 |
200 | println!(" 🎉 {}: http://{}\n", "Served at".green().bold(), addr);
201 |
202 | server.await.map_err(|_| println!("Server error")).unwrap();
203 | }
204 | }
205 |
206 | #[derive(Clone)]
207 | struct AppServer {
208 | router: Router,
209 | }
210 |
211 | impl AppServer {
212 | pub async fn resolve_endpoint(
213 | req: Request,
214 | route_value: Option,
215 | app_state: Option,
216 | ) -> Result, hyper::Error>
217 | where
218 | T: Send + Sync + 'static,
219 | {
220 | match route_value {
221 | Some(route_value) => {
222 | let route = match route_value.get_route(req.method()) {
223 | Some(r) => r,
224 | None => return Ok::<_, hyper::Error>(page_not_found()),
225 | };
226 | let middlewares = route_value.get_middlewares();
227 | let params = route_value.get_params();
228 | let mut context = Context::new(req, params);
229 | let executor = EndpointExecutor::new(&route.handler, middlewares);
230 |
231 | if let Some(state) = app_state {
232 | context.add::(state);
233 | }
234 |
235 | let route_result = executor.next(context).await;
236 |
237 | let route_response = match route_result {
238 | Ok(ctx) => {
239 | let mut res = Response::builder();
240 | if let Some(response) = ctx.take_response() {
241 | if let Some(headers) = response.headers() {
242 | if let Some(response_headers) = res.headers_mut() {
243 | headers.iter().for_each(move |(key, value)| {
244 | response_headers
245 | .insert(key, header::HeaderValue::from_static(value));
246 | });
247 | }
248 | }
249 | res.status(response.status()).body(response.body())
250 | } else {
251 | // No response found
252 | res.status(StatusCode::OK).body(Body::from(""))
253 | }
254 | }
255 | Err(err) => {
256 | let body = Body::from(err.to_string());
257 | Response::builder()
258 | .status(StatusCode::INTERNAL_SERVER_ERROR)
259 | .body(body)
260 | }
261 | };
262 |
263 | Ok::<_, hyper::Error>(route_response.unwrap_or_else(|_| {
264 | internal_server_error(ObsidianError::GeneralError(
265 | "Error while constructing response body".to_string(),
266 | ))
267 | }))
268 | }
269 | _ => Ok::<_, hyper::Error>(page_not_found()),
270 | }
271 | }
272 | }
273 |
274 | fn page_not_found() -> Response {
275 | let mut server_response = Response::new(Body::from("404 Not Found"));
276 | *server_response.status_mut() = StatusCode::NOT_FOUND;
277 |
278 | server_response
279 | }
280 |
281 | fn internal_server_error(err: impl std::error::Error) -> Response {
282 | let body = Body::from(err.to_string());
283 | Response::builder()
284 | .status(StatusCode::INTERNAL_SERVER_ERROR)
285 | .body(body)
286 | .unwrap()
287 | }
288 |
289 | pub struct EndpointExecutor<'a> {
290 | pub route_endpoint: &'a Arc,
291 | pub middleware: &'a [Arc],
292 | }
293 |
294 | impl<'a> EndpointExecutor<'a> {
295 | pub fn new(
296 | route_endpoint: &'a Arc,
297 | middleware: &'a [Arc],
298 | ) -> Self {
299 | EndpointExecutor {
300 | route_endpoint,
301 | middleware,
302 | }
303 | }
304 |
305 | pub async fn next(mut self, context: Context) -> ContextResult {
306 | if let Some((current, all_next)) = self.middleware.split_first() {
307 | self.middleware = all_next;
308 | current.handle(context, self).await
309 | } else {
310 | self.route_endpoint.call(context).await
311 | }
312 | }
313 | }
314 |
#[cfg(test)]
mod test {
    use super::*;
    use crate::context::Context;
    use async_std::task;
    use hyper::{body, body::Buf, StatusCode};

    /// Read `payload` into a `String`; any transport or UTF-8 failure fails the test.
    async fn body_text(payload: Body) -> String {
        let buf = match body::aggregate(payload).await {
            Ok(buf) => buf,
            Err(_) => panic!(),
        };
        String::from_utf8(buf.chunk().to_vec()).unwrap()
    }

    #[test]
    fn test_app_server_resolve_endpoint() {
        task::block_on(async {
            let mut router = Router::new();

            router.get("/", |mut ctx: Context| async move {
                let received = body_text(ctx.take_body()).await;

                assert_eq!(ctx.uri().path(), "/");
                assert_eq!(received, "test_app_server");
                ctx.build("test_app_server").ok()
            });

            let app_server = AppServer { router };

            let req = Request::builder()
                .uri("/")
                .body(Body::from("test_app_server"))
                .unwrap();

            let route_value = app_server.router.search_route(req.uri().path());
            let actual_response = AppServer::resolve_endpoint::<()>(req, route_value, None)
                .await
                .unwrap();

            let mut expected_response = Response::new(Body::from("test_app_server"));
            *expected_response.status_mut() = StatusCode::OK;

            assert_eq!(actual_response.status(), expected_response.status());

            let actual_text = body_text(actual_response.into_body()).await;
            let expected_text = body_text(expected_response.into_body()).await;
            assert_eq!(actual_text, expected_text);
        })
    }
}
376 |
--------------------------------------------------------------------------------
/src/context.rs:
--------------------------------------------------------------------------------
1 | use http::Extensions;
2 | use hyper::{body, body::Buf};
3 | use serde::de::DeserializeOwned;
4 | use serde::ser::Serialize;
5 | use url::form_urlencoded;
6 |
7 | use std::borrow::Cow;
8 | use std::collections::HashMap;
9 | use std::convert::From;
10 | use std::str::FromStr;
11 |
12 | use crate::router::{from_cow_map, ContextResult, Responder, Response};
13 | use crate::ObsidianError;
14 | use crate::{
15 | header::{HeaderName, HeaderValue},
16 | Body, HeaderMap, Method, Request, StatusCode, Uri,
17 | };
18 |
19 | /// Context contains the data for current http connection context.
20 | /// For example, request information, params, method, and path.
21 | #[derive(Debug)]
22 | pub struct Context {
23 | request: Request,
24 | params_data: HashMap,
25 | response: Option,
26 | }
27 |
28 | impl Context {
29 | pub fn new(request: Request, params_data: HashMap) -> Self {
30 | Self {
31 | request,
32 | params_data,
33 | response: None,
34 | }
35 | }
36 |
37 | /// Access request headers
38 | pub fn headers(&self) -> &HeaderMap {
39 | self.request.headers()
40 | }
41 |
42 | /// Access mutable request header
43 | pub fn headers_mut(&mut self) -> &mut HeaderMap {
44 | self.request.headers_mut()
45 | }
46 |
47 | /// Access request method
48 | pub fn method(&self) -> &Method {
49 | self.request.method()
50 | }
51 |
52 | /// Access request uri
53 | pub fn uri(&self) -> &Uri {
54 | self.request.uri()
55 | }
56 |
57 | /// Access request extensions
58 | pub fn extensions(&self) -> &Extensions {
59 | self.request.extensions()
60 | }
61 |
62 | /// Access mutable request extensions
63 | pub fn extensions_mut(&mut self) -> &mut Extensions {
64 | self.request.extensions_mut()
65 | }
66 |
67 | /// Add dynamic data into request extensions
68 | pub fn add(&mut self, ctx_data: T) {
69 | self.extensions_mut().insert(ctx_data);
70 | }
71 |
72 | /// Get dynamic data from request extensions
73 | pub fn get(&self) -> Option<&T> {
74 | self.extensions().get::()
75 | }
76 |
77 | /// Get mutable dynamic data from request extensions
78 | pub fn get_mut(&mut self) -> Option<&mut T> {
79 | self.extensions_mut().get_mut::()
80 | }
81 |
82 | /// Method to get the params value according to key.
83 | /// Panic if key is not found.
84 | ///
85 | /// # Example
86 | ///
87 | /// ```
88 | /// # use obsidian::StatusCode;
89 | /// # use obsidian::ContextResult;
90 | /// # use obsidian::context::Context;
91 | ///
92 | /// // Assumming ctx contains params for id and mode
93 | /// async fn get_handler(ctx: Context) -> ContextResult {
94 | /// let id: i32 = ctx.param("id")?;
95 | /// let mode: String = ctx.param("mode")?;
96 | ///
97 | /// assert_eq!(id, 1);
98 | /// assert_eq!(mode, "edit".to_string());
99 | ///
100 | /// ctx.build("").ok()
101 | /// }
102 | ///
103 | /// ```
104 | ///
105 | pub fn param(&self, key: &str) -> Result {
106 | self.params_data
107 | .get(key)
108 | .ok_or(ObsidianError::NoneError)?
109 | .parse()
110 | .map_err(|_err| ObsidianError::ParamError(format!("Failed to parse param {}", key)))
111 | }
112 |
113 | /// Method to get the string query data from the request url.
114 | /// Untagged is not supported
115 | ///
116 | /// # Example
117 | /// ```
118 | /// # use serde::*;
119 | ///
120 | /// # use obsidian::context::Context;
121 | /// # use obsidian::{ContextResult, StatusCode};
122 | ///
123 | /// #[derive(Deserialize, Serialize, Debug)]
124 | /// struct QueryString {
125 | /// id: i32,
126 | /// mode: String,
127 | /// }
128 | ///
129 | /// // Assume ctx contains string query with data {id=1&mode=edit}
130 | /// async fn get_handler(mut ctx: Context) -> ContextResult {
131 | /// let result: QueryString = ctx.query_params()?;
132 | ///
133 | /// assert_eq!(result.id, 1);
134 | /// assert_eq!(result.mode, "edit".to_string());
135 | ///
136 | /// ctx.build("").ok()
137 | /// }
138 | /// ```
139 | pub fn query_params(&mut self) -> Result {
140 | let query = match self.uri().query() {
141 | Some(query) => query,
142 | _ => "",
143 | }
144 | .as_bytes();
145 |
146 | Self::parse_queries(&query)
147 | }
148 |
149 | /// Method to get the forms query data from the request body.
150 | /// Body is consumed after calling this method.
151 | /// Untagged is not supported
152 | ///
153 | /// # Example
154 | /// ```
155 | /// # use serde::*;
156 | ///
157 | /// # use obsidian::context::Context;
158 | /// # use obsidian::{ContextResult, StatusCode};
159 | ///
160 | /// #[derive(Deserialize, Serialize, Debug)]
161 | /// struct FormResult {
162 | /// id: i32,
163 | /// mode: String,
164 | /// }
165 | ///
166 | /// // Assume ctx contains form query with data {id=1&mode=edit}
167 | /// async fn get_handler(mut ctx: Context) -> ContextResult {
168 | /// let result: FormResult = ctx.form().await?;
169 | ///
170 | /// assert_eq!(result.id, 1);
171 | /// assert_eq!(result.mode, "edit".to_string());
172 | ///
173 | /// ctx.build("").ok()
174 | /// }
175 | /// ```
176 | pub async fn form(&mut self) -> Result {
177 | let body = self.take_body();
178 |
179 | let buf = match body::aggregate(body).await {
180 | Ok(buf) => buf,
181 | _ => {
182 | return Err(ObsidianError::NoneError);
183 | }
184 | };
185 |
186 | Self::parse_queries(buf.chunk())
187 | }
188 |
189 | /// Form value merge with Params
190 | pub fn form_with_param(&mut self) -> Result {
191 | unimplemented!()
192 | }
193 |
194 | /// Method to get the json data from the request body. Body is consumed after calling this method.
195 | /// The result can be either handled by using static type or dynamic map.
196 | /// Panic if parsing fail.
197 | ///
198 | /// # Example
199 | ///
200 | /// ### Handle by static type
201 | /// ```
202 | /// # use serde::*;
203 | ///
204 | /// # use obsidian::context::Context;
205 | /// # use obsidian::{ContextResult, StatusCode};
206 | ///
207 | /// #[derive(Deserialize, Serialize, Debug)]
208 | /// struct JsonResult {
209 | /// id: i32,
210 | /// mode: String,
211 | /// }
212 | ///
213 | /// // Assume ctx contains json with data {id:1, mode:'edit'}
214 | /// async fn get_handler(mut ctx: Context) -> ContextResult {
215 | /// let result: JsonResult = ctx.json().await?;
216 | ///
217 | /// assert_eq!(result.id, 1);
218 | /// assert_eq!(result.mode, "edit".to_string());
219 | ///
220 | /// ctx.build("").ok()
221 | /// }
222 | /// ```
223 | ///
224 | /// ### Handle by dynamic map
225 | /// ```
226 | /// # use serde_json::Value;
227 | ///
228 | /// # use obsidian::context::Context;
229 | /// # use obsidian::{ContextResult, StatusCode};
230 | ///
231 | /// // Assume ctx contains json with data {id:1, mode:'edit'}
232 | /// async fn get_handler(mut ctx: Context) -> ContextResult {
233 | /// let result: serde_json::Value = ctx.json().await?;
234 | ///
235 | /// assert_eq!(result["id"], 1);
236 | /// assert_eq!(result["mode"], "edit".to_string());
237 | ///
238 | /// ctx.build("").ok()
239 | /// }
240 | /// ```
241 | pub async fn json(&mut self) -> Result {
242 | let body = self.take_body();
243 |
244 | let buf = match body::aggregate(body).await {
245 | Ok(buf) => buf,
246 | _ => {
247 | return Err(ObsidianError::NoneError);
248 | }
249 | };
250 |
251 | Ok(serde_json::from_slice(buf.chunk())?)
252 | }
253 |
254 | /// Json value merged with Params
255 | pub fn json_with_param(&mut self) -> Result {
256 | unimplemented!()
257 | }
258 |
259 | /// Consumes body of the request and replace it with empty body.
260 | pub fn take_body(&mut self) -> Body {
261 | std::mem::replace(self.request.body_mut(), Body::empty())
262 | }
263 |
264 | /// Take response
265 | pub fn take_response(self) -> Option {
266 | self.response
267 | }
268 |
269 | pub fn response(&self) -> &Option {
270 | &self.response
271 | }
272 |
273 | pub fn response_mut(&mut self) -> &mut Option {
274 | &mut self.response
275 | }
276 |
277 | /// Build any kind of response which implemented Responder trait
278 | pub fn build(self, res: impl Responder) -> ResponseBuilder {
279 | ResponseBuilder::new(self, res.respond_to())
280 | }
281 |
282 | /// Build data into json format. The data must implement Serialize trait
283 | pub fn build_json(self, body: impl Serialize) -> ResponseBuilder {
284 | ResponseBuilder::new(self, Response::ok().json(body))
285 | }
286 |
287 | /// Build response from static file.
288 | pub async fn build_file(self, file_path: &str) -> ResponseBuilder {
289 | ResponseBuilder::new(self, Response::ok().file(file_path).await)
290 | }
291 |
292 | fn parse_queries(query: &[u8]) -> Result {
293 | let mut parsed_form_map: HashMap> = HashMap::default();
294 | let mut cow_form_map = HashMap::, Cow<[String]>>::default();
295 |
296 | // Parse and merge chunks with same name key
297 | form_urlencoded::parse(query)
298 | .into_owned()
299 | .for_each(|(key, val)| {
300 | if !val.is_empty() {
301 | parsed_form_map
302 | .entry(key)
303 | .or_insert_with(Vec::new)
304 | .push(val);
305 | }
306 | });
307 |
308 | // Wrap vec with cow pointer
309 | parsed_form_map.iter().for_each(|(key, val)| {
310 | cow_form_map
311 | .entry(std::borrow::Cow::from(key))
312 | .or_insert_with(|| std::borrow::Cow::from(val));
313 | });
314 |
315 | Ok(from_cow_map(&cow_form_map)?)
316 | }
317 | }
318 |
319 | pub struct ResponseBuilder {
320 | ctx: Context,
321 | response: Response,
322 | }
323 |
324 | impl ResponseBuilder {
325 | pub fn new(ctx: Context, response: Response) -> Self {
326 | ResponseBuilder { ctx, response }
327 | }
328 |
329 | /// set http status code for response
330 | pub fn with_status(mut self, status: StatusCode) -> Self {
331 | self.response = self.response.set_status(status);
332 | self
333 | }
334 |
335 | /// set http header for response
336 | pub fn with_header(mut self, key: HeaderName, value: &'static str) -> Self {
337 | self.response = self.response.set_header(key, value);
338 | self
339 | }
340 |
341 | /// set custom http header for response with `&str` key
342 | pub fn with_header_str(mut self, key: &'static str, value: &'static str) -> Self {
343 | self.response = self.response.set_header_str(key, value);
344 | self
345 | }
346 |
347 | pub fn with_headers(mut self, headers: Vec<(HeaderName, &'static str)>) -> Self {
348 | self.response = self.response.set_headers(headers);
349 | self
350 | }
351 |
352 | pub fn with_headers_str(mut self, headers: Vec<(&'static str, &'static str)>) -> Self {
353 | self.response = self.response.set_headers_str(headers);
354 | self
355 | }
356 |
357 | pub fn ok(mut self) -> ContextResult {
358 | *self.ctx.response_mut() = Some(self.response);
359 | Ok(self.ctx)
360 | }
361 | }
362 |
#[cfg(test)]
mod test {
    use super::*;
    use async_std::task;
    use hyper::{Body, Request};
    use serde::*;
    use serde_json::json;

    #[derive(Deserialize, Serialize, Debug, PartialEq)]
    struct FormResult {
        id: i32,
        mode: String,
    }

    #[derive(Deserialize, Serialize, Debug, PartialEq)]
    struct FormExtraResult {
        id: i32,
        mode: String,
        #[serde(default)]
        extra: i32,
    }

    #[derive(Deserialize, Serialize, Debug, PartialEq)]
    struct JsonResult {
        id: i32,
        mode: String,
    }

    #[derive(Deserialize, Serialize, Debug, PartialEq)]
    struct JsonExtraResult {
        id: i32,
        mode: String,
        #[serde(default)]
        extra: i32,
    }

    /// Context over a request whose body is `payload`, with no route params.
    fn body_context(payload: &'static str) -> Context {
        Context::new(Request::new(Body::from(payload)), HashMap::default())
    }

    /// The `{ id: 1, mode: "edit" }` value the form/query tests expect.
    fn edit_form() -> FormResult {
        FormResult {
            id: 1,
            mode: "edit".to_string(),
        }
    }

    #[test]
    fn test_params() -> Result<(), ObsidianError> {
        let params_map: HashMap<String, String> = [("id", "1"), ("mode", "edit")]
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect();

        let ctx = Context::new(Request::new(Body::from("")), params_map);

        let id: i32 = ctx.param("id")?;
        let mode: String = ctx.param("mode")?;

        assert_eq!(id, 1);
        assert_eq!(mode, "edit".to_string());

        Ok(())
    }

    #[test]
    #[should_panic]
    fn test_params_without_value() {
        let mut params_map = HashMap::default();
        params_map.insert("mode".to_string(), "edit".to_string());

        let ctx = Context::new(Request::new(Body::from("")), params_map);

        let _mode: String = ctx.param("mode").unwrap();
        let _id: i32 = ctx.param("id").unwrap();
    }

    #[test]
    fn test_string_query() -> Result<(), ObsidianError> {
        let mut request = Request::new(Body::from(""));
        *request.uri_mut() = Uri::from_str("/test/test?id=1&mode=edit").unwrap();

        let mut ctx = Context::new(request, HashMap::default());

        let actual_result: FormResult = ctx.query_params()?;
        assert_eq!(actual_result, edit_form());
        Ok(())
    }

    #[test]
    fn test_form() -> Result<(), ObsidianError> {
        task::block_on(async {
            let mut ctx = body_context("id=1&mode=edit");

            let actual_result: FormResult = ctx.form().await?;
            assert_eq!(actual_result, edit_form());
            Ok(())
        })
    }

    #[test]
    fn test_form_with_extra_body() -> Result<(), ObsidianError> {
        task::block_on(async {
            let mut ctx = body_context("id=1&mode=edit&extra=true");

            let actual_result: FormResult = ctx.form().await?;
            assert_eq!(actual_result, edit_form());
            Ok(())
        })
    }

    #[test]
    fn test_form_with_extra_field() -> Result<(), ObsidianError> {
        task::block_on(async {
            let mut ctx = body_context("id=1&mode=edit");

            let actual_result: FormExtraResult = ctx.form().await?;
            let expected_result = FormExtraResult {
                id: 1,
                mode: "edit".to_string(),
                extra: i32::default(),
            };

            assert_eq!(actual_result, expected_result);
            Ok(())
        })
    }

    #[test]
    fn test_json_struct() -> Result<(), ObsidianError> {
        task::block_on(async {
            let mut ctx = body_context("{\"id\":1,\"mode\":\"edit\"}");

            let actual_result: JsonResult = ctx.json().await?;
            let expected_result = JsonResult {
                id: 1,
                mode: "edit".to_string(),
            };

            assert_eq!(actual_result, expected_result);
            Ok(())
        })
    }

    #[test]
    fn test_json_value() -> Result<(), ObsidianError> {
        task::block_on(async {
            let mut ctx = body_context("{\"id\":1,\"mode\":\"edit\"}");

            let actual_result: serde_json::Value = ctx.json().await?;

            assert_eq!(actual_result["id"], json!(1));
            assert_eq!(actual_result["mode"], json!("edit"));
            Ok(())
        })
    }
}
546 |
--------------------------------------------------------------------------------
/src/error.rs:
--------------------------------------------------------------------------------
1 | mod obsidian_error;
2 |
3 | pub use obsidian_error::ObsidianError;
4 |
--------------------------------------------------------------------------------
/src/error/obsidian_error.rs:
--------------------------------------------------------------------------------
1 | use std::error::Error;
2 | use std::fmt;
3 | use std::fmt::Display;
4 |
5 | use serde_json::error::Error as JsonError;
6 |
7 | use crate::router::FormError;
8 |
9 | /// Errors occurs in Obsidian framework
10 | #[derive(Debug)]
11 | pub enum ObsidianError {
12 | ParamError(String),
13 | JsonError(JsonError),
14 | FormError(FormError),
15 | GeneralError(String),
16 | NoneError,
17 | }
18 |
19 | impl Display for ObsidianError {
20 | fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
21 | let error_msg = match *self {
22 | ObsidianError::ParamError(ref msg) => msg.to_string(),
23 | ObsidianError::JsonError(ref err) => err.to_string(),
24 | ObsidianError::FormError(ref err) => err.to_string(),
25 | ObsidianError::GeneralError(ref msg) => msg.to_string(),
26 | ObsidianError::NoneError => "Input should not be None".to_string(),
27 | };
28 |
29 | formatter.write_str(&error_msg)
30 | }
31 | }
32 |
33 | impl From for ObsidianError {
34 | fn from(error: FormError) -> Self {
35 | ObsidianError::FormError(error)
36 | }
37 | }
38 |
39 | impl From for ObsidianError {
40 | fn from(error: JsonError) -> Self {
41 | ObsidianError::JsonError(error)
42 | }
43 | }
44 |
45 | impl Error for ObsidianError {
46 | fn description(&self) -> &str {
47 | "Obsidian Error"
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | //#[deny(missing_docs)]
2 |
3 | mod app;
4 | pub mod error;
5 |
6 | pub mod context;
7 | pub mod middleware;
8 | pub mod router;
9 |
10 | pub use app::{App, EndpointExecutor};
11 | pub use error::ObsidianError;
12 | pub use hyper::{header, Body, HeaderMap, Method, Request, Response, StatusCode, Uri, Version};
13 | pub use router::ContextResult;
14 |
--------------------------------------------------------------------------------
/src/middleware.rs:
--------------------------------------------------------------------------------
1 | pub mod logger;
2 |
3 | use async_trait::async_trait;
4 |
5 | use crate::app::EndpointExecutor;
6 | use crate::context::Context;
7 | use crate::router::ContextResult;
8 |
9 | #[async_trait]
10 | pub trait Middleware: Send + Sync + 'static {
11 | async fn handle<'a>(
12 | &'a self,
13 | context: Context,
14 | ep_executor: EndpointExecutor<'a>,
15 | ) -> ContextResult;
16 | }
17 |
--------------------------------------------------------------------------------
/src/middleware/logger.rs:
--------------------------------------------------------------------------------
1 | use async_trait::async_trait;
2 | use std::time::Instant;
3 |
4 | use crate::app::EndpointExecutor;
5 | use crate::context::Context;
6 | use crate::middleware::Middleware;
7 | use crate::router::ContextResult;
8 |
9 | use colored::*;
10 |
/// Middleware that prints each request and the time taken to produce its result.
#[derive(Default)]
pub struct Logger {}
13 |
14 | impl Logger {
15 | pub fn new() -> Self {
16 | Logger {}
17 | }
18 | }
19 |
20 | #[async_trait]
21 | impl Middleware for Logger {
22 | async fn handle<'a>(
23 | &'a self,
24 | context: Context,
25 | ep_executor: EndpointExecutor<'a>,
26 | ) -> ContextResult {
27 | let start = Instant::now();
28 | println!("[info] {} {}", context.method(), context.uri(),);
29 |
30 | #[cfg(debug_assertions)]
31 | println!("{} {:#?}", "[debug]".cyan(), context);
32 |
33 | match ep_executor.next(context).await {
34 | Ok(context_after) => {
35 | let duration = start.elapsed();
36 | println!(
37 | "[info] Sent {} in {:?}",
38 | context_after.response().as_ref().unwrap().status(),
39 | duration
40 | );
41 | Ok(context_after)
42 | }
43 | Err(error) => {
44 | println!("{} {}", "[error]".red(), error);
45 | Err(error)
46 | }
47 | }
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/src/router.rs:
--------------------------------------------------------------------------------
1 | mod handler;
2 | mod req_deserializer;
3 | mod resource;
4 | mod responder;
5 | mod response;
6 | mod response_body;
7 | mod route;
8 | mod route_trie;
9 |
10 | use self::route_trie::RouteTrie;
11 | use crate::context::Context;
12 | use crate::middleware::Middleware;
13 | use crate::Method;
14 | pub use hyper::header;
15 |
16 | pub use self::handler::{ContextResult, Handler};
17 | pub use self::req_deserializer::{from_cow_map, Error as FormError};
18 | pub use self::resource::Resource;
19 | pub use self::responder::Responder;
20 | pub use self::response::Response;
21 | pub use self::response_body::ResponseBody;
22 | pub use self::route::Route;
23 |
24 | pub(crate) use self::route_trie::RouteValueResult;
25 |
26 | pub struct Router {
27 | routes: RouteTrie,
28 | }
29 |
30 | impl Clone for Router {
31 | fn clone(&self) -> Self {
32 | Router {
33 | routes: self.routes.clone(),
34 | }
35 | }
36 | }
37 |
38 | impl Default for Router {
39 | fn default() -> Self {
40 | Self::new()
41 | }
42 | }
43 |
44 | impl Router {
45 | pub fn new() -> Self {
46 | Router {
47 | routes: RouteTrie::new(),
48 | }
49 | }
50 |
51 | pub fn get(&mut self, path: &str, handler: impl Handler) {
52 | self.insert_route(Method::GET, path, handler);
53 | }
54 |
55 | pub fn post(&mut self, path: &str, handler: impl Handler) {
56 | self.insert_route(Method::POST, path, handler);
57 | }
58 |
59 | pub fn put(&mut self, path: &str, handler: impl Handler) {
60 | self.insert_route(Method::PUT, path, handler);
61 | }
62 |
63 | pub fn patch(&mut self, path: &str, handler: impl Handler) {
64 | self.insert_route(Method::PATCH, path, handler);
65 | }
66 |
67 | pub fn delete(&mut self, path: &str, handler: impl Handler) {
68 | self.insert_route(Method::DELETE, path, handler);
69 | }
70 |
71 | /// Apply middleware in the provided route
72 | pub fn use_service_to(&mut self, path: &str, middleware: impl Middleware) {
73 | self.routes.insert_middleware(path, middleware);
74 | }
75 |
76 | /// Apply middleware in current relative route
77 | pub fn use_service(&mut self, middleware: impl Middleware) {
78 | self.routes.insert_default_middleware(middleware);
79 | }
80 |
81 | /// Serve static files by the virtual path as the route and directory path as the server file path
82 | pub fn use_static_to(&mut self, virtual_path: &str, dir_path: &str) {
83 | let mut path = String::from(virtual_path);
84 | path.push_str("/*");
85 |
86 | self.get(
87 | &path,
88 | Self::static_virtual_file_handler(virtual_path, dir_path),
89 | );
90 | }
91 |
92 | /// Serve static files by the directory path as the route and server file path
93 | pub fn use_static(&mut self, dir_path: &str) {
94 | let mut path = String::from(dir_path);
95 | path.push_str("/*");
96 |
97 | self.get(&path, Self::static_dir_file_handler);
98 | }
99 |
100 | /// Apply route handler in current relative route
101 | pub fn use_router(&mut self, path: &str, other: Router) {
102 | RouteTrie::insert_sub_route(&mut self.routes, path, other.routes);
103 | }
104 |
105 | pub fn search_route(&self, path: &str) -> Option {
106 | self.routes.search_route(path)
107 | }
108 |
109 | fn insert_route(&mut self, method: Method, path: &str, handler: impl Handler) {
110 | let route = Route::new(method, handler);
111 |
112 | self.routes.insert_route(path, route);
113 | }
114 |
115 | fn static_virtual_file_handler(virtual_path: &str, dir_path: &str) -> impl Handler {
116 | let dir_path = dir_path
117 | .split('/')
118 | .filter(|key| !key.is_empty())
119 | .map(|x| x.to_string())
120 | .collect::>();
121 |
122 | let virtual_path_len = virtual_path
123 | .split('/')
124 | .filter(|key| !key.is_empty())
125 | .count();
126 |
127 | move |ctx: Context| {
128 | let mut dir_path = dir_path.clone();
129 | let mut relative_path = ctx
130 | .uri()
131 | .path()
132 | .split('/')
133 | .filter(|key| !key.is_empty())
134 | .skip(virtual_path_len)
135 | .map(|x| x.to_string())
136 | .collect::>();
137 |
138 | dir_path.append(&mut relative_path);
139 |
140 | Box::pin(async move {
141 | ctx.build(Response::ok().file(&dir_path.join("/")).await)
142 | .ok()
143 | })
144 | }
145 | }
146 |
147 | async fn static_dir_file_handler(ctx: Context) -> ContextResult {
148 | let relative_path = ctx
149 | .uri()
150 | .path()
151 | .split('/')
152 | .filter(|key| !key.is_empty())
153 | .map(|x| x.to_string())
154 | .collect::>();
155 |
156 | ctx.build(Response::ok().file(&relative_path.join("/")).await)
157 | .ok()
158 | }
159 | }
160 |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::Context;
    use crate::middleware::logger::Logger;

    /// Trivial handler registered for every route in these tests.
    async fn handler(ctx: Context) -> ContextResult {
        ctx.build("test").ok()
    }

    /// Asserts that `path` resolves to a resource with exactly
    /// `middleware_count` middlewares and a route for every method in `methods`.
    fn assert_resolves(router: &Router, path: &str, middleware_count: usize, methods: &[Method]) {
        let resource = router
            .search_route(path)
            .unwrap_or_else(|| panic!("expected `{}` to resolve", path));

        assert_eq!(resource.get_middlewares().len(), middleware_count);

        for method in methods {
            let route_value = resource.get_route(method).unwrap();
            assert_eq!(route_value.method, *method);
        }
    }

    /// Asserts that `path` does not resolve to anything.
    fn assert_missing(router: &Router, path: &str) {
        assert!(
            router.search_route(path).is_none(),
            "expected `{}` not to resolve",
            path
        );
    }

    #[test]
    fn router_get_test() {
        let mut router = Router::new();
        router.get("router/test", handler);

        assert_resolves(&router, "router/test", 0, &[Method::GET]);
        assert_missing(&router, "failed");
    }

    #[test]
    fn router_post_test() {
        let mut router = Router::new();
        router.post("router/test", handler);

        assert_resolves(&router, "router/test", 0, &[Method::POST]);
        assert_missing(&router, "failed");
    }

    #[test]
    fn router_put_test() {
        let mut router = Router::new();
        router.put("router/test", handler);

        assert_resolves(&router, "router/test", 0, &[Method::PUT]);
        assert_missing(&router, "failed");
    }

    #[test]
    fn router_delete_test() {
        let mut router = Router::new();
        router.delete("router/test", handler);

        assert_resolves(&router, "router/test", 0, &[Method::DELETE]);
        assert_missing(&router, "failed");
    }

    #[test]
    fn router_root_middleware_test() {
        let mut router = Router::new();
        router.use_service(Logger::new());

        assert_resolves(&router, "/", 1, &[]);
        assert_missing(&router, "failed");
    }

    #[test]
    fn router_relative_middleware_test() {
        let mut router = Router::new();
        router.use_service_to("middleware/child", Logger::new());

        assert_resolves(&router, "/middleware/child", 1, &[]);
        assert_missing(&router, "/");
    }

    #[test]
    fn router_search_test() {
        let mut router = Router::new();

        router.get("router/test", handler);
        router.post("router/test", handler);
        router.put("router/test", handler);
        router.delete("router/test", handler);
        router.get("route/diff_route", handler);

        assert_resolves(
            &router,
            "router/test",
            0,
            &[Method::GET, Method::POST, Method::PUT, Method::DELETE],
        );
        assert_resolves(&router, "route/diff_route", 0, &[Method::GET]);
        assert_missing(&router, "failed");
    }

    #[test]
    fn router_merge_test() {
        let mut main_router = Router::new();
        let mut sub_router = Router::new();

        main_router.get("router/test", handler);
        sub_router.get("router/test", handler);
        sub_router.use_service(Logger::new());

        main_router.use_router("sub_router", sub_router);

        assert_resolves(&main_router, "router/test", 0, &[Method::GET]);
        assert_resolves(&main_router, "sub_router/router/test", 1, &[Method::GET]);
        assert_missing(&main_router, "failed");
    }

    #[should_panic]
    #[test]
    fn router_duplicate_path_test() {
        let mut router = Router::new();

        router.get("router/test", handler);
        router.get("router/test", handler);
    }

    #[should_panic]
    #[test]
    fn router_ambiguous_path_test() {
        let mut router = Router::new();

        router.get("router/:test", handler);
        router.get("router/test", handler);
    }

    #[should_panic]
    #[test]
    fn router_duplicate_merge_test() {
        let mut main_router = Router::new();
        let mut sub_router = Router::new();

        main_router.get("sub_router/test", handler);
        sub_router.get("test", handler);
        sub_router.use_service(Logger::new());

        main_router.use_router("sub_router", sub_router);
    }
}
449 |
--------------------------------------------------------------------------------
/src/router/handler.rs:
--------------------------------------------------------------------------------
1 | use crate::context::Context;
2 | use crate::error::ObsidianError;
3 |
4 | use async_trait::async_trait;
5 | use std::future::Future;
6 |
/// Result type produced by every [`Handler`] (the return type of
/// [`Handler::call`]).
///
/// NOTE(review): presumably the error side is `ObsidianError` (imported in
/// this module) — confirm the alias's type parameters against the source.
pub type ContextResult = Result;
8 |
9 | #[async_trait]
10 | pub trait Handler: Send + Sync + 'static {
11 | async fn call(&self, ctx: Context) -> ContextResult;
12 | }
13 |
14 | #[async_trait]
15 | impl Handler for T
16 | where
17 | T: Fn(Context) -> F + Send + Sync + 'static,
18 | F: Future + Send + 'static,
19 | {
20 | async fn call(&self, ctx: Context) -> ContextResult {
21 | (self)(ctx).await
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/src/router/req_deserializer.rs:
--------------------------------------------------------------------------------
1 | use serde::de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor};
2 | use serde::forward_to_deserialize_any;
3 | use serde::ser;
4 | use serde::Deserialize;
5 |
6 | use std::borrow::Cow;
7 | use std::collections::HashMap;
8 | use std::fmt;
9 | use std::fmt::Display;
10 |
11 | /// Parse merged forms key, value pair get from form_urlencoded into a user defined struct
12 | /// Key and Value should be in Cow pointer
13 | ///
14 | /// # Example
15 | ///
16 | /// ```
17 | /// # use obsidian::router::from_cow_map;
18 | /// # use hyper::{Body, Request, body, body::Buf};
19 | /// # use url::form_urlencoded;
20 | /// # use serde::*;
21 | /// # use std::collections::HashMap;
22 | /// # use std::borrow::Cow;
23 | /// # use async_std::task;
24 | ///
25 | /// #[derive(Deserialize, Debug, PartialEq)]
26 | /// struct Example {
27 | /// field1: Vec,
28 | /// field2: i32,
29 | /// }
30 | /// task::block_on(
31 | /// async {
32 | /// let body = Request::new(Body::from("field1=1&field1=2&field2=12")).into_body();
33 | ///
34 | /// let buf = match body::aggregate(body).await {
35 | /// Ok(buf) => buf,
36 | /// Err(e) => {
37 | /// println!("{}", e);
38 | /// panic!()
39 | /// }
40 | /// };
41 | ///
42 | /// let mut parsed_form_map: HashMap> = HashMap::default();
43 | /// let mut cow_form_map = HashMap::, Cow<[String]>>::default();
44 | ///
45 | /// // Parse and merge chunks with same name key
46 | /// form_urlencoded::parse(buf.chunk())
47 | /// .into_owned()
48 | /// .for_each(|(key, val)| {
49 | /// parsed_form_map.entry(key).or_insert(vec![]).push(val);
50 | /// });
51 | ///
52 | /// // Wrap vec with cow pointer
53 | /// parsed_form_map.iter().for_each(|(key, val)| {
54 | /// cow_form_map
55 | /// .entry(std::borrow::Cow::from(key))
56 | /// .or_insert(std::borrow::Cow::from(val));
57 | /// });
58 | ///
59 | /// let actual_result: Example = from_cow_map(&cow_form_map).unwrap();
60 | /// let expected_result = Example{field1: vec![1,2], field2:12};
61 | ///
62 | /// assert_eq!(actual_result, expected_result);
63 | /// })
64 | /// ```
65 | pub fn from_cow_map<'de, T, S: ::std::hash::BuildHasher>(
66 | s: &'de HashMap, Cow<'de, [String]>, S>,
67 | ) -> Result
68 | where
69 | T: Deserialize<'de>,
70 | {
71 | let mut deserializer = FormDeserializer::from_cow_map(s.iter().peekable());
72 | let t = T::deserialize(&mut deserializer)?;
73 | Ok(t)
74 | }
75 |
/// Deserializer over a merged form map, where each entry pairs a field name
/// with every value submitted under that name.
///
/// The iterator is peekable so that the entry currently being read stays
/// available while its key and then its values are deserialized.
struct FormDeserializer<'de> {
    input: std::iter::Peekable<
        std::collections::hash_map::Iter<'de, Cow<'de, str>, Cow<'de, [String]>>,
    >,
}
86 |
/// Generates `Deserializer` methods that read the *current map key* as a
/// primitive type.
///
/// Form keys are always text, so each generated method parses the key with
/// `str::parse::<$t>()` before handing the typed value to the visitor. This
/// allows maps keyed by numbers or booleans (e.g. `HashMap<i32, Vec<String>>`).
/// A key that does not parse becomes [`Error::Message`]; a missing key becomes
/// [`Error::NoneError`].
macro_rules! from_string_forms_key_impl {
    ($($t:ty => $method:ident)*) => {$(
        fn $method<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where V: de::Visitor<'de>
        {
            match self.input.peek() {
                Some(key) => match key.0.parse::<$t>() {
                    Ok(result) => result.into_deserializer().$method(visitor),
                    Err(e) => Err(Error::Message(e.to_string())),
                },
                None => Err(Error::NoneError),
            }
        }
    )*}
}
99 |
100 | impl<'de> FormDeserializer<'de> {
101 | pub fn from_cow_map(
102 | input: std::iter::Peekable<
103 | std::collections::hash_map::Iter<
104 | 'de,
105 | std::borrow::Cow<'de, str>,
106 | std::borrow::Cow<'de, [String]>,
107 | >,
108 | >,
109 | ) -> Self {
110 | FormDeserializer { input }
111 | }
112 | }
113 |
114 | impl<'de, 'a> de::Deserializer<'de> for &'a mut FormDeserializer<'de> {
115 | type Error = Error;
116 |
117 | fn deserialize_any(self, visitor: V) -> Result
118 | where
119 | V: Visitor<'de>,
120 | {
121 | self.deserialize_map(visitor)
122 | }
123 |
124 | fn deserialize_map(self, visitor: V) -> Result
125 | where
126 | V: Visitor<'de>,
127 | {
128 | visitor.visit_map(FromMap::new(self))
129 | }
130 |
131 | fn deserialize_string(self, visitor: V) -> Result
132 | where
133 | V: Visitor<'de>,
134 | {
135 | self.deserialize_str(visitor)
136 | }
137 |
138 | fn deserialize_str(self, visitor: V) -> Result
139 | where
140 | V: Visitor<'de>,
141 | {
142 | match self.input.peek() {
143 | Some(key) => visitor.visit_str(key.0),
144 | _ => Err(Error::NoneError),
145 | }
146 | }
147 |
148 | fn deserialize_identifier(self, visitor: V) -> Result
149 | where
150 | V: Visitor<'de>,
151 | {
152 | self.deserialize_str(visitor)
153 | }
154 |
155 | forward_to_deserialize_any! {
156 | char
157 | option
158 | bytes
159 | byte_buf
160 | unit_struct
161 | newtype_struct
162 | tuple_struct
163 | struct
164 | tuple
165 | enum
166 | ignored_any
167 | unit
168 | seq
169 | }
170 |
171 | from_string_forms_key_impl! {
172 | bool => deserialize_bool
173 | u8 => deserialize_u8
174 | u16 => deserialize_u16
175 | u32 => deserialize_u32
176 | u64 => deserialize_u64
177 | i8 => deserialize_i8
178 | i16 => deserialize_i16
179 | i32 => deserialize_i32
180 | i64 => deserialize_i64
181 | f32 => deserialize_f32
182 | f64 => deserialize_f64
183 | }
184 | }
185 |
/// Generates `Deserializer` methods that parse the *current form value* into a
/// primitive type with `str::parse::<$t>()`.
///
/// A value that does not parse becomes [`Error::Message`] carrying the parse
/// error's text. An exhausted input becomes [`Error::NoneError`] instead of
/// panicking on an out-of-bounds index.
macro_rules! from_string_forms_impl {
    ($($t:ty => $method:ident)*) => {$(
        fn $method<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where V: de::Visitor<'de>
        {
            let raw = match self.input.first() {
                Some(raw) => raw,
                None => return Err(Error::NoneError),
            };

            match raw.parse::<$t>() {
                Ok(result) => result.into_deserializer().$method(visitor),
                Err(e) => Err(Error::Message(e.to_string())),
            }
        }
    )*}
}
198 |
/// Deserializer for the values of one form field: every value submitted under
/// the field's name, in submission order.
struct FormValueDeserializer<'de> {
    input: &'de [String],
}

impl<'de> FormValueDeserializer<'de> {
    /// Creates a deserializer positioned at the first of `input`'s values.
    pub fn new(input: &'de [String]) -> Self {
        Self { input }
    }
}
209 |
210 | impl<'de> de::Deserializer<'de> for &mut FormValueDeserializer<'de> {
211 | type Error = Error;
212 |
213 | fn deserialize_any(self, visitor: V) -> Result
214 | where
215 | V: Visitor<'de>,
216 | {
217 | self.deserialize_string(visitor)
218 | }
219 |
220 | fn deserialize_seq(self, visitor: V) -> Result
221 | where
222 | V: Visitor<'de>,
223 | {
224 | visitor.visit_seq(FromSeq::new(self))
225 | }
226 |
227 | fn deserialize_string(self, visitor: V) -> Result
228 | where
229 | V: Visitor<'de>,
230 | {
231 | visitor.visit_string(self.input[0].clone())
232 | }
233 |
234 | fn deserialize_str(self, visitor: V) -> Result
235 | where
236 | V: Visitor<'de>,
237 | {
238 | visitor.visit_str(&self.input[0])
239 | }
240 |
241 | fn deserialize_char(self, visitor: V) -> Result
242 | where
243 | V: Visitor<'de>,
244 | {
245 | match self.input[0].chars().next() {
246 | Some(val) => visitor.visit_char(val),
247 | _ => Err(Error::NoneError),
248 | }
249 | }
250 |
251 | fn deserialize_option(self, visitor: V) -> Result
252 | where
253 | V: Visitor<'de>,
254 | {
255 | if self.input.starts_with(&[String::default()]) {
256 | self.input = &self.input[1..];
257 | visitor.visit_none()
258 | } else {
259 | visitor.visit_some(self)
260 | }
261 | }
262 |
263 | fn deserialize_unit(self, visitor: V) -> Result
264 | where
265 | V: Visitor<'de>,
266 | {
267 | if self.input.starts_with(&[String::default()]) {
268 | self.input = &self.input[1..];
269 | }
270 |
271 | visitor.visit_unit()
272 | }
273 |
274 | fn deserialize_unit_struct(self, _name: &'static str, visitor: V) -> Result
275 | where
276 | V: Visitor<'de>,
277 | {
278 | self.deserialize_unit(visitor)
279 | }
280 |
281 | fn deserialize_newtype_struct(
282 | self,
283 | _name: &'static str,
284 | visitor: V,
285 | ) -> Result
286 | where
287 | V: Visitor<'de>,
288 | {
289 | visitor.visit_newtype_struct(self)
290 | }
291 |
292 | fn deserialize_tuple(self, _len: usize, visitor: V) -> Result
293 | where
294 | V: Visitor<'de>,
295 | {
296 | self.deserialize_seq(visitor)
297 | }
298 |
299 | fn deserialize_identifier(self, visitor: V) -> Result
300 | where
301 | V: Visitor<'de>,
302 | {
303 | self.deserialize_str(visitor)
304 | }
305 |
306 | forward_to_deserialize_any! {
307 | bytes
308 | byte_buf
309 | ignored_any
310 | map
311 | struct
312 | tuple_struct
313 | enum
314 | }
315 |
316 | from_string_forms_impl! {
317 | bool => deserialize_bool
318 | u8 => deserialize_u8
319 | u16 => deserialize_u16
320 | u32 => deserialize_u32
321 | u64 => deserialize_u64
322 | i8 => deserialize_i8
323 | i16 => deserialize_i16
324 | i32 => deserialize_i32
325 | i64 => deserialize_i64
326 | f32 => deserialize_f32
327 | f64 => deserialize_f64
328 | }
329 | }
330 |
331 | struct FromSeq<'a, 'de: 'a> {
332 | de: &'a mut FormValueDeserializer<'de>,
333 | first: bool,
334 | }
335 |
336 | impl<'a, 'de> FromSeq<'a, 'de> {
337 | fn new(de: &'a mut FormValueDeserializer<'de>) -> Self {
338 | FromSeq { de, first: true }
339 | }
340 | }
341 |
342 | impl<'de, 'a> SeqAccess<'de> for FromSeq<'a, 'de> {
343 | type Error = Error;
344 |
345 | fn next_element_seed(&mut self, seed: T) -> Result, Error>
346 | where
347 | T: DeserializeSeed<'de>,
348 | {
349 | if self.de.input.len() == 1 && !self.first {
350 | return Ok(None);
351 | }
352 |
353 | // Only start moving slices after processing
354 | if !self.first {
355 | self.de.input = &self.de.input[1..];
356 | }
357 | self.first = false;
358 |
359 | seed.deserialize(&mut *self.de).map(Some)
360 | }
361 | }
362 |
363 | struct FromMap<'a, 'de: 'a> {
364 | de: &'a mut FormDeserializer<'de>,
365 | first: bool,
366 | }
367 |
368 | impl<'a, 'de> FromMap<'a, 'de> {
369 | fn new(de: &'a mut FormDeserializer<'de>) -> Self {
370 | FromMap { de, first: true }
371 | }
372 | }
373 |
374 | impl<'de, 'a> MapAccess<'de> for FromMap<'a, 'de> {
375 | type Error = Error;
376 |
377 | fn next_key_seed(&mut self, seed: K) -> Result, Error>
378 | where
379 | K: DeserializeSeed<'de>,
380 | {
381 | if !self.first {
382 | self.de.input.next();
383 | }
384 |
385 | self.first = false;
386 |
387 | match self.de.input.peek() {
388 | Some(_x) => seed.deserialize(&mut *self.de).map(Some),
389 | _ => Ok(None),
390 | }
391 | }
392 |
393 | fn next_value_seed(&mut self, seed: V) -> Result
394 | where
395 | V: DeserializeSeed<'de>,
396 | {
397 | match self.de.input.peek() {
398 | Some(val) => seed.deserialize(&mut FormValueDeserializer::new(val.1)),
399 | _ => Err(Error::NoneError),
400 | }
401 | }
402 | }
403 |
/// Error produced by the request (form) deserializer.
#[derive(Debug, Clone, PartialEq)]
pub enum Error {
    /// Free-form message, e.g. a parse failure or a serde `custom` error.
    Message(String),
    /// A key or value was required but none was available.
    NoneError,
}
410 |
411 | impl ser::Error for Error {
412 | fn custom(msg: T) -> Self {
413 | Error::Message(msg.to_string())
414 | }
415 | }
416 |
417 | impl de::Error for Error {
418 | fn custom(msg: T) -> Self {
419 | Error::Message(msg.to_string())
420 | }
421 | }
422 |
423 | impl Display for Error {
424 | fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
425 | match *self {
426 | Error::Message(ref msg) => formatter.write_str(msg),
427 | Error::NoneError => formatter.write_str("Input should not be None"),
428 | }
429 | }
430 | }
431 |
432 | impl std::error::Error for Error {}
433 |
#[cfg(test)]
mod tests {
    use super::*;
    use async_std::task;
    use hyper::{body, body::Buf, Body, Request};
    use url::form_urlencoded;

    #[derive(Deserialize, Debug, PartialEq)]
    struct VecAndSingleVariableStruct {
        field1: Vec<String>,
        field2: i32,
    }

    #[derive(Deserialize, Debug, PartialEq)]
    struct VecStruct {
        field1: Vec<i32>,
    }

    #[derive(Deserialize, Debug, PartialEq)]
    struct VecWithDefaultStruct {
        field1: Vec<i32>,
        #[serde(default)]
        field2: i32,
    }

    /// Aggregates `raw` as a request body, parses it as a urlencoded form and
    /// merges values that share a key, in submission order.
    fn parse_form(raw: &'static str) -> HashMap<String, Vec<String>> {
        task::block_on(async {
            let body = Request::new(Body::from(raw)).into_body();
            let buf = match body::aggregate(body).await {
                Ok(buf) => buf,
                Err(e) => panic!("Body parsing fail {}", e),
            };

            let mut parsed_form_map: HashMap<String, Vec<String>> = HashMap::default();
            form_urlencoded::parse(buf.chunk())
                .into_owned()
                .for_each(|(key, val)| {
                    parsed_form_map.entry(key).or_default().push(val);
                });
            parsed_form_map
        })
    }

    /// Wraps every key and value list of `parsed` in a borrowed `Cow`, the
    /// shape `from_cow_map` expects.
    fn to_cow_map(parsed: &HashMap<String, Vec<String>>) -> HashMap<Cow<'_, str>, Cow<'_, [String]>> {
        parsed
            .iter()
            .map(|(key, val)| (Cow::from(key.as_str()), Cow::from(val.as_slice())))
            .collect()
    }

    #[test]
    fn test_deserialize_to_struct_with_vec_and_single_variable() {
        let parsed = parse_form("field1=abc&field1=xyz&field2=12");
        let cow_form_map = to_cow_map(&parsed);

        let actual_result: VecAndSingleVariableStruct = from_cow_map(&cow_form_map).unwrap();
        let expected_result = VecAndSingleVariableStruct {
            field1: vec!["abc".to_string(), "xyz".to_string()],
            field2: 12,
        };
        assert_eq!(actual_result, expected_result);
    }

    #[test]
    fn test_deserialize_to_struct_with_vec() {
        let parsed = parse_form("field1=1&field1=2");
        let cow_form_map = to_cow_map(&parsed);

        let actual_result: VecStruct = from_cow_map(&cow_form_map).unwrap();
        let expected_result = VecStruct { field1: vec![1, 2] };
        assert_eq!(actual_result, expected_result);
    }

    #[test]
    fn test_deserialize_to_struct_with_extra_form_value() {
        let parsed = parse_form("field1=1&field1=2&field2=12");
        let cow_form_map = to_cow_map(&parsed);

        let actual_result: VecStruct = from_cow_map(&cow_form_map).unwrap();
        let expected_result = VecStruct { field1: vec![1, 2] };
        assert_eq!(actual_result, expected_result);
    }

    #[test]
    fn test_deserialize_to_struct_with_extra_struct_field() {
        let parsed = parse_form("field1=1&field1=2");
        let cow_form_map = to_cow_map(&parsed);

        let actual_result: VecWithDefaultStruct = from_cow_map(&cow_form_map).unwrap();
        let expected_result = VecWithDefaultStruct {
            field1: vec![1, 2],
            field2: i32::default(),
        };
        assert_eq!(actual_result, expected_result);
    }

    #[test]
    fn test_deserialize_to_map_type() {
        let parsed = parse_form("field1=1&field1=2&field2=3");
        let cow_form_map = to_cow_map(&parsed);

        let actual_result: HashMap<String, Vec<i32>> = from_cow_map(&cow_form_map).unwrap();
        let mut expected_result: HashMap<String, Vec<i32>> = HashMap::default();
        expected_result.insert("field1".to_string(), vec![1, 2]);
        expected_result.insert("field2".to_string(), vec![3]);
        assert_eq!(actual_result, expected_result);
    }
}
648 |
--------------------------------------------------------------------------------
/src/router/resource.rs:
--------------------------------------------------------------------------------
1 | use hyper::Method;
2 | use std::collections::HashMap;
3 |
4 | use super::Route;
5 |
6 | /// Resource acts as the intermidiate interface for interaction of routing data structure
7 | /// Resource is binding with the path and handling all of the request method for that path
8 | #[derive(Clone, Debug)]
9 | pub struct Resource {
10 | route_map: HashMap,
11 | }
12 |
13 | impl Default for Resource {
14 | fn default() -> Self {
15 | Resource {
16 | route_map: HashMap::new(),
17 | }
18 | }
19 | }
20 |
21 | impl Resource {
22 | pub fn add_route(&mut self, method: Method, route: Route) -> Option {
23 | self.route_map.insert(method, route)
24 | }
25 |
26 | pub fn get_route(&self, method: &Method) -> Option<&Route> {
27 | self.route_map.get(&method)
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/src/router/responder.rs:
--------------------------------------------------------------------------------
1 | use super::Response;
2 | use super::ResponseBody;
3 | use hyper::{header, StatusCode};
4 |
5 | pub trait Responder {
6 | fn respond_to(self) -> Response;
7 | fn with_status(self, status: StatusCode) -> Response
8 | where
9 | Self: Responder + ResponseBody + Sized,
10 | {
11 | Response::new(self).set_status(status)
12 | }
13 |
14 | fn with_header(self, key: header::HeaderName, value: &'static str) -> Response
15 | where
16 | Self: Responder + ResponseBody + Sized,
17 | {
18 | Response::new(self).set_header(key, value)
19 | }
20 |
21 | fn with_headers(self, headers: Vec<(header::HeaderName, &'static str)>) -> Response
22 | where
23 | Self: Responder + ResponseBody + Sized,
24 | {
25 | Response::new(self).set_headers(headers)
26 | }
27 |
28 | fn with_headers_str(self, headers: Vec<(&'static str, &'static str)>) -> Response
29 | where
30 | Self: Responder + ResponseBody + Sized,
31 | {
32 | Response::new(self).set_headers_str(headers)
33 | }
34 | }
35 |
36 | impl Responder for Response {
37 | fn respond_to(self) -> Response {
38 | self
39 | }
40 | }
41 |
42 | impl Responder for String {
43 | fn respond_to(self) -> Response {
44 | Response::new(self).set_content_type("text/plain; charset=utf-8")
45 | }
46 | }
47 |
48 | impl Responder for &'static str {
49 | fn respond_to(self) -> Response {
50 | self.to_string().respond_to()
51 | }
52 | }
53 |
54 | impl Responder for () {
55 | fn respond_to(self) -> Response {
56 | Response::new(())
57 | }
58 | }
59 |
60 | impl Responder for (StatusCode, T)
61 | where
62 | T: Responder + ResponseBody,
63 | {
64 | fn respond_to(self) -> Response {
65 | let (status_code, body) = self;
66 | Response::new(body).set_status(status_code)
67 | }
68 | }
69 |
/// Responds with the vector serialized to a JSON string by `serde_json`.
///
/// On success the JSON text is the body with `200 OK`. On failure the error
/// is printed to stderr and its message is returned with
/// `500 Internal Server Error`.
///
/// NOTE(review): the element type's bound is not visible here; presumably it
/// is `serde::Serialize`, as `serde_json::to_string` requires — confirm.
/// NOTE(review): the body is wrapped via `Responder::with_status`
/// (`Response::new`), so no `Content-Type` header is set — neither
/// `application/json` nor the `String` responder's `text/plain`.
impl Responder for Vec {
    fn respond_to(self) -> Response {
        match serde_json::to_string(&self) {
            Ok(json) => json.with_status(StatusCode::OK).respond_to(),
            Err(e) => {
                // Log server-side, then expose the error text to the client.
                eprintln!("serializing failed: {}", e);
                let error = e.to_string();
                error
                    .with_status(StatusCode::INTERNAL_SERVER_ERROR)
                    .respond_to()
            }
        }
    }
}
84 |
85 | impl Responder for StatusCode {
86 | fn respond_to(self) -> Response {
87 | ().with_status(self).respond_to()
88 | }
89 | }
90 |
/// Responds with the wrapped value's own response. `None` becomes
/// `404 Not Found` with the text body `"Not Found"`.
///
/// NOTE(review): the wrapped type is not visible here. Because there is a
/// separate `Option<&'static str>` impl in this file, this is presumably a
/// concrete type (e.g. `String`) rather than a blanket impl — confirm.
impl Responder for Option {
    fn respond_to(self) -> Response {
        match self {
            Some(resp) => resp.respond_to(),
            None => "Not Found"
                .to_string()
                .with_status(StatusCode::NOT_FOUND)
                .respond_to(),
        }
    }
}
102 |
103 | impl Responder for Option<&'static str> {
104 | fn respond_to(self) -> Response {
105 | match self {
106 | Some(resp) => resp.respond_to(),
107 | None => "Not Found"
108 | .to_string()
109 | .with_status(StatusCode::NOT_FOUND)
110 | .respond_to(),
111 | }
112 | }
113 | }
114 |
#[cfg(test)]
mod test {
    use super::*;
    use hyper::StatusCode;

    #[test]
    fn test_str_responder() {
        let response = "Hello World".respond_to();
        assert_eq!(response.status(), StatusCode::OK);
        // TODO: add testing for body once the Responder is refactored
    }

    #[test]
    fn test_string_responder() {
        let response = "Hello World".to_string().respond_to();
        assert_eq!(response.status(), StatusCode::OK);
        // TODO: add testing for body once the Responder is refactored
    }

    #[test]
    fn test_option_responder() {
        let some_response = Some("Hello World").respond_to();
        assert_eq!(some_response.status(), StatusCode::OK);

        let none_response = None::<&'static str>.respond_to();
        assert_eq!(none_response.status(), StatusCode::NOT_FOUND);
        // TODO: add testing for body once the Responder is refactored
    }

    // TODO: add a `Result` responder test once `Responder` is implemented for
    // `Result` (no such impl exists in this module yet).

    #[test]
    fn test_responder_with_custom_status() {
        let response = "Test".with_status(StatusCode::CREATED).respond_to();
        assert_eq!(response.status(), StatusCode::CREATED);
    }

    #[test]
    fn test_responder_with_custom_header() {
        let response = "Test"
            .with_header(header::CONTENT_TYPE, "application/json")
            .respond_to();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(response
            .headers()
            .as_ref()
            .unwrap()
            .contains(&(header::CONTENT_TYPE, "application/json")));
    }
}
183 |
--------------------------------------------------------------------------------
/src/router/response.rs:
--------------------------------------------------------------------------------
1 | use super::ResponseBody;
2 |
3 | use async_std::fs;
4 | use http::StatusCode;
5 | use hyper::{header, Body};
6 | use serde::ser::Serialize;
7 |
8 | #[derive(Debug)]
9 | pub struct Response {
10 | body: Body,
11 | status: StatusCode,
12 | headers: Option>,
13 | }
14 |
15 | impl Response {
16 | pub fn new(body: impl ResponseBody) -> Self {
17 | Response {
18 | body: body.into_body(),
19 | status: StatusCode::OK,
20 | headers: None,
21 | }
22 | }
23 |
24 | pub fn status(&self) -> StatusCode {
25 | self.status
26 | }
27 |
28 | pub fn status_mut(&mut self) -> &mut StatusCode {
29 | &mut self.status
30 | }
31 |
32 | pub fn body(self) -> Body {
33 | self.body
34 | }
35 |
36 | pub fn headers(&self) -> &Option> {
37 | &self.headers
38 | }
39 |
40 | pub fn headers_mut(&mut self) -> &mut Option> {
41 | &mut self.headers
42 | }
43 |
44 | pub fn with_status(self, status: StatusCode) -> Self {
45 | self.set_status(status)
46 | }
47 |
48 | pub fn set_status(mut self, status: http::StatusCode) -> Self {
49 | self.status = status;
50 | self
51 | }
52 |
53 | pub fn set_body(mut self, body: impl ResponseBody) -> Self {
54 | self.body = body.into_body();
55 | self
56 | }
57 |
58 | pub fn set_header(mut self, key: header::HeaderName, value: &'static str) -> Self {
59 | match self.headers {
60 | Some(ref mut x) => x.push((key, value)),
61 | None => self.headers = Some(vec![(key, value)]),
62 | };
63 | self
64 | }
65 |
66 | // Alias set_header method
67 | pub fn with_header(self, key: header::HeaderName, value: &'static str) -> Self {
68 | self.set_header(key, value)
69 | }
70 |
71 | pub fn set_header_str(self, key: &'static str, value: &'static str) -> Self {
72 | self.set_header(
73 | header::HeaderName::from_bytes(key.as_bytes()).unwrap(),
74 | value,
75 | )
76 | }
77 |
78 | // Alias set_header_str method
79 | pub fn with_header_str(self, key: &'static str, value: &'static str) -> Self {
80 | self.set_header_str(key, value)
81 | }
82 |
83 | pub fn set_content_type(self, content_type: &'static str) -> Self {
84 | self.set_header(header::CONTENT_TYPE, content_type)
85 | }
86 |
87 | pub fn set_headers(mut self, headers: Vec<(header::HeaderName, &'static str)>) -> Self {
88 | match self.headers {
89 | Some(ref mut x) => x.extend_from_slice(&headers),
90 | None => self.headers = Some(headers),
91 | };
92 | self
93 | }
94 |
95 | // Alias set_headers method
96 | pub fn with_headers(self, headers: Vec<(header::HeaderName, &'static str)>) -> Self {
97 | self.set_headers(headers)
98 | }
99 |
100 | pub fn set_headers_str(mut self, headers: Vec<(&'static str, &'static str)>) -> Self {
101 | let values: Vec<(header::HeaderName, &'static str)> = headers
102 | .iter()
103 | .map(|(k, v)| (header::HeaderName::from_bytes(k.as_bytes()).unwrap(), *v))
104 | .collect();
105 |
106 | match self.headers {
107 | Some(ref mut x) => x.extend_from_slice(&values),
108 | None => self.headers = Some(values),
109 | };
110 | self
111 | }
112 |
113 | // Alias set_headers_str method
114 | pub fn with_headers_str(self, headers: Vec<(&'static str, &'static str)>) -> Self {
115 | self.set_headers_str(headers)
116 | }
117 |
118 | pub fn html(self, body: impl ResponseBody) -> Self {
119 | self.set_content_type("text/html").set_body(body)
120 | }
121 |
122 | pub fn json(self, body: impl Serialize) -> Self {
123 | match serde_json::to_string(&body) {
124 | Ok(val) => self.set_content_type("application/json").set_body(val),
125 | Err(err) => self
126 | .set_content_type("application/json")
127 | .set_body(err.to_string())
128 | .set_status(StatusCode::INTERNAL_SERVER_ERROR),
129 | }
130 | }
131 |
132 | pub async fn file(self, file_path: &str) -> Self {
133 | match fs::read_to_string(file_path.to_string()).await {
134 | Ok(content) => self.set_body(content),
135 | Err(err) => {
136 | dbg!(&err);
137 | self.set_body(err.to_string())
138 | .set_status(StatusCode::NOT_FOUND)
139 | }
140 | }
141 | }
142 |
143 | // Utilities
144 | pub fn ok() -> Self {
145 | Response::new(()).with_status(StatusCode::OK)
146 | }
147 |
148 | pub fn created() -> Self {
149 | Response::new(()).with_status(StatusCode::CREATED)
150 | }
151 |
152 | pub fn internal_server_error() -> Self {
153 | Response::new(()).with_status(StatusCode::INTERNAL_SERVER_ERROR)
154 | }
155 | }
156 |
#[cfg(test)]
mod test {
    use super::*;
    use hyper::StatusCode;
    use serde::*;

    /// True when `response` carries exactly the `(name, value)` header pair.
    /// Panics (failing the test) if the response has no headers at all.
    fn has_header(response: &Response, name: header::HeaderName, value: &'static str) -> bool {
        response
            .headers()
            .as_ref()
            .unwrap()
            .contains(&(name, value))
    }

    #[test]
    fn test_response() {
        let response = Response::new("Hello World");
        assert_eq!(response.status(), StatusCode::OK);
        // TODO: add testing for body once the Responder is refactored
    }

    #[test]
    fn test_response_utilities_status() {
        let cases = [
            (Response::ok(), StatusCode::OK),
            (Response::created(), StatusCode::CREATED),
            (
                Response::internal_server_error(),
                StatusCode::INTERNAL_SERVER_ERROR,
            ),
        ];

        for (response, expected) in cases.iter() {
            assert_eq!(response.status(), *expected);
        }
    }

    #[test]
    fn test_complete_response() {
        #[derive(Serialize, Deserialize, Debug)]
        struct Person {
            name: String,
            age: i8,
        }

        let response = Response::created()
            .set_header(header::AUTHORIZATION, "token")
            .json(Person {
                name: String::from("Jun Kai"),
                age: 27,
            });

        assert_eq!(response.status(), StatusCode::CREATED);
        assert!(has_header(&response, header::CONTENT_TYPE, "application/json"));
        assert!(has_header(&response, header::AUTHORIZATION, "token"));
    }
}
209 |
--------------------------------------------------------------------------------
/src/router/response_body.rs:
--------------------------------------------------------------------------------
1 | use hyper::Body;
2 |
3 | pub trait ResponseBody {
4 | fn into_body(self) -> Body;
5 | }
6 |
7 | impl ResponseBody for () {
8 | fn into_body(self) -> Body {
9 | Body::empty()
10 | }
11 | }
12 |
13 | impl ResponseBody for &'static str {
14 | fn into_body(self) -> Body {
15 | Body::from(self)
16 | }
17 | }
18 |
19 | impl ResponseBody for String {
20 | fn into_body(self) -> Body {
21 | Body::from(self)
22 | }
23 | }
24 |
25 | impl ResponseBody for Vec {
26 | fn into_body(self) -> Body {
27 | match serde_json::to_string(&self) {
28 | Ok(json) => Body::from(json),
29 | Err(e) => {
30 | eprintln!("serializing failed: {}", e);
31 | Body::from(e.to_string())
32 | }
33 | }
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/src/router/route.rs:
--------------------------------------------------------------------------------
1 | use std::sync::Arc;
2 |
3 | use super::Handler;
4 | use crate::Method;
5 |
6 | pub struct Route {
7 | pub method: Method,
8 | pub handler: Arc,
9 | }
10 |
11 | impl std::fmt::Debug for Route {
12 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13 | write!(f, "Route {{ method: {} }}", self.method)
14 | }
15 | }
16 |
17 | impl Clone for Route {
18 | fn clone(&self) -> Route {
19 | Route {
20 | method: self.method.clone(),
21 | handler: self.handler.clone(),
22 | }
23 | }
24 | }
25 |
26 | impl Route {
27 | pub fn new(method: Method, handler: impl Handler) -> Self {
28 | Route {
29 | method,
30 | handler: Arc::new(handler),
31 | }
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/src/router/route_trie.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 | use std::fmt;
3 | use std::sync::Arc;
4 |
5 | use hyper::Method;
6 |
7 | use crate::middleware::Middleware;
8 | use crate::router::Resource;
9 | use crate::router::Route;
10 | use crate::ObsidianError;
11 |
12 | #[derive(Clone, Default)]
13 | pub struct RouteValue {
14 | middlewares: Vec>,
15 | route: Resource,
16 | }
17 |
18 | impl fmt::Debug for RouteValue {
19 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 | write!(f, "")
21 | }
22 | }
23 |
24 | impl RouteValue {
25 | pub fn new(middlewares: Vec>, route: Resource) -> Self {
26 | RouteValue { middlewares, route }
27 | }
28 | }
29 |
30 | pub struct RouteValueResult {
31 | route_value: RouteValue,
32 | params: HashMap,
33 | }
34 |
35 | impl RouteValueResult {
36 | pub fn new(route_value: RouteValue, params: HashMap) -> Self {
37 | RouteValueResult {
38 | route_value,
39 | params,
40 | }
41 | }
42 |
43 | pub fn get_route(&self, method: &Method) -> Option<&Route> {
44 | self.route_value.route.get_route(method)
45 | }
46 |
47 | pub fn get_middlewares(&self) -> &Vec> {
48 | &self.route_value.middlewares
49 | }
50 |
51 | pub fn get_params(&self) -> HashMap {
52 | self.params.clone()
53 | }
54 | }
55 |
56 | #[derive(Clone, Debug)]
57 | pub struct RouteTrie {
58 | head: Node,
59 | }
60 |
61 | impl RouteTrie {
62 | pub fn new() -> Self {
63 | RouteTrie {
64 | head: Node::new("/".to_string(), None),
65 | }
66 | }
67 |
68 | /// Insert middleware into root node
69 | pub fn insert_default_middleware(&mut self, middleware: impl Middleware) {
70 | match &mut self.head.value {
71 | Some(val) => {
72 | val.middlewares.push(Arc::new(middleware));
73 | }
74 | None => {
75 | let mut val = RouteValue::default();
76 | val.middlewares.push(Arc::new(middleware));
77 |
78 | self.head.value = Some(val);
79 | }
80 | }
81 | }
82 |
83 | /// Insert route values into the trie
84 | /// Panic if ambigous definition is detected
85 | pub fn insert_route(&mut self, path: &str, route: Route) {
86 | // Split path string and drop additional '/'
87 | let mut split_key = path.split('/').filter(|key| !key.is_empty()).peekable();
88 |
89 | split_key.clone().enumerate().for_each(|(pos, x)| {
90 | if x.contains('*') {
91 | if x.len() != 1 {
92 | panic!("ERROR: Consisting * in route name at: {}", path);
93 | } else if pos != split_key.clone().count() - 1 {
94 | panic!("ERROR: * must be in the last at: {}", path);
95 | }
96 | }
97 | });
98 |
99 | let mut curr_node = &mut self.head;
100 |
101 | // if the path is "/"
102 | if split_key.peek().is_none() {
103 | self.insert_default_route(route);
104 | return;
105 | }
106 |
107 | while let Some(k) = split_key.next() {
108 | match curr_node.process_insertion(k) {
109 | Ok(next_node) => {
110 | if split_key.peek().is_none() {
111 | match &mut next_node.value {
112 | Some(val) => {
113 | if let Some(duplicated) =
114 | val.route.add_route(route.method.clone(), route)
115 | {
116 | panic!(
117 | "Duplicated route method '{}' at '{}' detected",
118 | duplicated.method, path
119 | );
120 | }
121 | }
122 | None => {
123 | let mut next_node_val = RouteValue::default();
124 | if let Some(duplicated) =
125 | next_node_val.route.add_route(route.method.clone(), route)
126 | {
127 | panic!(
128 | "Duplicated route method '{}' at '{}' detected",
129 | duplicated.method, path
130 | );
131 | }
132 |
133 | next_node.value = Some(next_node_val);
134 | }
135 | }
136 | break;
137 | }
138 | curr_node = next_node;
139 | }
140 | Err(err) => {
141 | panic!("Insert Route: {} at {}", err, path);
142 | }
143 | }
144 | }
145 | }
146 |
147 | /// Insert middleware into specific node
148 | pub fn insert_middleware(&mut self, path: &str, middleware: impl Middleware) {
149 | // Split key and drop additional '/'
150 | let split_key = path.split('/');
151 | let mut split_key = split_key.filter(|key| !key.is_empty()).peekable();
152 |
153 | split_key.clone().enumerate().for_each(|(pos, key)| {
154 | if key.contains('*') {
155 | if key.len() != 1 {
156 | panic!("ERROR: Consisting * in route name at: {}", path);
157 | } else if pos != split_key.clone().count() - 1 {
158 | panic!("ERROR: * must be in the last at: {}", path);
159 | }
160 | }
161 | });
162 |
163 | let mut curr_node = &mut self.head;
164 |
165 | while let Some(k) = split_key.next() {
166 | match curr_node.process_insertion(k) {
167 | Ok(next_node) => {
168 | if split_key.peek().is_none() {
169 | match &mut next_node.value {
170 | Some(val) => {
171 | val.middlewares.push(Arc::new(middleware));
172 | }
173 | None => {
174 | let mut next_node_val = RouteValue::default();
175 | next_node_val.middlewares.push(Arc::new(middleware));
176 |
177 | next_node.value = Some(next_node_val);
178 | }
179 | }
180 | break;
181 | }
182 | curr_node = next_node;
183 | }
184 | Err(err) => {
185 | panic!("Middleware: {} at {}", err, path);
186 | }
187 | }
188 | }
189 | }
190 |
191 | /// Search node through the provided key
192 | /// Middleware will be accumulated throughout the search path
193 | pub fn search_route(&self, path: &str) -> Option {
194 | // Split key and drop additional '/'
195 | let split_key = path.split('/');
196 | let mut split_key = split_key
197 | .filter(|key| !key.is_empty())
198 | .collect::>();
199 |
200 | let mut curr_node = &self.head;
201 | let mut params = HashMap::default();
202 | let mut middlewares = vec![];
203 |
204 | match &curr_node.value {
205 | Some(val) => {
206 | middlewares.append(&mut val.middlewares.clone());
207 | }
208 | None => {}
209 | }
210 |
211 | if !split_key.is_empty() {
212 | match curr_node.get_next_node(&mut split_key, &mut params, &mut middlewares, false) {
213 | Some(handler_node) => {
214 | curr_node = handler_node;
215 | }
216 | None => {
217 | // Path is not registered
218 | return None;
219 | }
220 | }
221 | }
222 |
223 | match &curr_node.value {
224 | Some(val) => {
225 | let route_val = RouteValue::new(middlewares, val.route.clone());
226 |
227 | Some(RouteValueResult::new(route_val, params))
228 | }
229 | None => None,
230 | }
231 | }
232 |
233 | /// Insert src trie into the des as a child trie
234 | /// src will be under the node of des with the key path
235 | ///
236 | /// For example, /src/ -> /des/ with 'example' key path
237 | /// src will be located at /des/example/src/
238 | pub fn insert_sub_route(des: &mut Self, path: &str, src: Self) {
239 | // Split key and drop additional '/'
240 | let split_key = path.split('/');
241 | let mut split_key = split_key.filter(|key| !key.is_empty()).peekable();
242 |
243 | split_key.clone().enumerate().for_each(|(pos, x)| {
244 | if x.contains('*') {
245 | if x.len() != 1 {
246 | panic!("ERROR: Consisting * in route name at: {}", path);
247 | } else if pos != split_key.clone().count() - 1 {
248 | panic!("ERROR: * must be in the last at: {}", path);
249 | }
250 | }
251 | });
252 |
253 | let mut curr_node = &mut des.head;
254 |
255 | if split_key.peek().is_none() {
256 | des.head = src.head;
257 | return;
258 | }
259 |
260 | while let Some(k) = split_key.next() {
261 | match curr_node.process_insertion(k) {
262 | Ok(next_node) => {
263 | if split_key.peek().is_none() {
264 | if next_node.value.is_some() || !next_node.child_nodes.is_empty() {
265 | panic!("There is conflict between main router and sub router at '{}'. Make sure main router does not consist any routing data in '{}'.", path, path);
266 | }
267 |
268 | next_node.value = src.head.value;
269 | next_node.child_nodes = src.head.child_nodes;
270 | break;
271 | }
272 | curr_node = next_node;
273 | }
274 | Err(err) => {
275 | panic!("SubRouter: {} at {}", err, path);
276 | }
277 | }
278 | }
279 | }
280 |
281 | fn insert_default_route(&mut self, route: Route) {
282 | match &mut self.head.value {
283 | Some(val) => {
284 | if let Some(duplicated) = val.route.add_route(route.method.clone(), route) {
285 | panic!(
286 | "Duplicated route method '{}' at '/' detected",
287 | duplicated.method
288 | );
289 | }
290 | }
291 | None => {
292 | let mut val = RouteValue::default();
293 | if let Some(duplicated) = val.route.add_route(route.method.clone(), route) {
294 | panic!(
295 | "Duplicated route method '{}' at '/' detected",
296 | duplicated.method
297 | );
298 | }
299 |
300 | self.head.value = Some(val);
301 | }
302 | }
303 | }
304 | }
305 |
306 | #[derive(Clone, Debug)]
307 | struct Node {
308 | key: String,
309 | value: Option,
310 | child_nodes: Vec,
311 | }
312 |
313 | impl Node {
314 | fn new(key: String, value: Option) -> Self {
315 | Node {
316 | key,
317 | value,
318 | child_nodes: Vec::default(),
319 | }
320 | }
321 |
322 | fn is_param(&self) -> bool {
323 | self.key.chars().next().unwrap_or(' ') == ':'
324 | }
325 |
326 | /// Process the side effects of node insertion
327 | fn process_insertion(&mut self, key: &str) -> Result<&mut Self, ObsidianError> {
328 | let action = self.get_insertion_action(key);
329 |
330 | match action.name {
331 | ActionName::CreateNewNode => {
332 | let new_node = Self::new(key.to_string(), None);
333 |
334 | match key {
335 | k if k == "*" => {
336 | self.child_nodes.push(new_node);
337 |
338 | if let Some(node) = self.child_nodes.last_mut() {
339 | return Ok(node);
340 | };
341 | }
342 | _ => {
343 | self.child_nodes.insert(0, new_node);
344 |
345 | if let Some(node) = self.child_nodes.first_mut() {
346 | return Ok(node);
347 | };
348 | }
349 | }
350 | }
351 | ActionName::NextNode => {
352 | if let Some(node) = self.child_nodes.get_mut(action.payload.node_index) {
353 | return Ok(node);
354 | };
355 | }
356 | ActionName::SplitKey => {
357 | if let Some(node) = self.child_nodes.get_mut(action.payload.node_index) {
358 | return node.process_insertion(&key[action.payload.match_count..]);
359 | };
360 | }
361 | ActionName::SplitNode => {
362 | if let Some(node) = self.child_nodes.get_mut(action.payload.node_index) {
363 | let count = action.payload.match_count;
364 | let child_key = node.key[count..].to_string();
365 | let new_key = key[count..].to_string();
366 | node.key = key[..count].to_string();
367 |
368 | let mut inter_node = Self::new(child_key, None);
369 |
370 | // Move out the previous child and transfer to intermediate node
371 | inter_node.child_nodes = std::mem::take(&mut node.child_nodes);
372 | inter_node.value = std::mem::replace(&mut node.value, None);
373 |
374 | node.child_nodes.insert(0, inter_node);
375 |
376 | // In the case of insert key length less than matched node key length
377 | if new_key.is_empty() {
378 | return Ok(node);
379 | }
380 |
381 | let new_node = Self::new(new_key, None);
382 |
383 | node.child_nodes.insert(0, new_node);
384 | if let Some(result_node) = node.child_nodes.first_mut() {
385 | return Ok(result_node);
386 | }
387 | };
388 | }
389 | ActionName::Error => {
390 | if let Some(node) = self.child_nodes.get(action.payload.node_index) {
391 | return Err(ObsidianError::GeneralError(format!(
392 | "ERROR: Ambigous definition between {} and {}",
393 | key, node.key
394 | )));
395 | }
396 | }
397 | }
398 |
399 | unreachable!();
400 | }
401 |
402 | /// Determine the action required to be performed for the new route path
403 | fn get_insertion_action(&self, key: &str) -> Action {
404 | for (index, node) in self.child_nodes.iter().enumerate() {
405 | let is_param = node.is_param() || key.chars().next().unwrap_or(' ') == ':';
406 | if is_param {
407 | // Only allow one param leaf in one children series
408 | if key == node.key {
409 | return Action::new(ActionName::NextNode, ActionPayload::new(0, index));
410 | } else {
411 | return Action::new(ActionName::Error, ActionPayload::new(0, index));
412 | }
413 | }
414 |
415 | let mut temp_key_chars = key.chars();
416 | let mut count = 0;
417 |
418 | // match characters
419 | for k in node.key.chars() {
420 | let t_k = match temp_key_chars.next() {
421 | Some(key) => key,
422 | None => break,
423 | };
424 |
425 | if t_k == k {
426 | count += t_k.len_utf8();
427 | } else {
428 | break;
429 | }
430 | }
431 |
432 | match count {
433 | x if x == key.len() && x == node.key.len() => {
434 | return Action::new(ActionName::NextNode, ActionPayload::new(x, index))
435 | }
436 | x if x == node.key.len() => {
437 | return Action::new(ActionName::SplitKey, ActionPayload::new(x, index))
438 | }
439 | x if x != 0 => {
440 | return Action::new(ActionName::SplitNode, ActionPayload::new(x, index))
441 | }
442 | _ => {}
443 | }
444 | }
445 |
446 | // No child node matched the key, creates new node
447 | Action::new(ActionName::CreateNewNode, ActionPayload::new(0, 0))
448 | }
449 |
450 | // Helper function to consume the whole key and get the next available node
451 | fn get_next_node(
452 | &self,
453 | key: &mut Vec<&str>,
454 | params: &mut HashMap,
455 | middlewares: &mut Vec>,
456 | is_break_parent: bool,
457 | ) -> Option<&Self> {
458 | let curr_key = key.remove(0);
459 |
460 | for node in self.child_nodes.iter() {
461 | let mut break_key = false;
462 |
463 | if !is_break_parent {
464 | // Check param
465 | if node.is_param() {
466 | if key.is_empty() {
467 | match &node.value {
468 | Some(curr_val) => {
469 | params.insert(node.key[1..].to_string(), curr_key.to_string());
470 | middlewares.append(&mut curr_val.middlewares.clone());
471 | return Some(node);
472 | }
473 | None => {
474 | continue;
475 | }
476 | }
477 | } else {
478 | match node.get_next_node(key, params, middlewares, break_key) {
479 | Some(final_val) => {
480 | params.insert(node.key[1..].to_string(), curr_key.to_string());
481 |
482 | match &node.value {
483 | Some(curr_val) => {
484 | middlewares.append(&mut curr_val.middlewares.clone());
485 | }
486 | None => {}
487 | }
488 |
489 | return Some(final_val);
490 | }
491 | None => {
492 | continue;
493 | }
494 | }
495 | }
496 | }
497 |
498 | // Check wildcard
499 | if node.key == "*" {
500 | match &node.value {
501 | Some(curr_val) => {
502 | middlewares.append(&mut curr_val.middlewares.clone());
503 | }
504 | None => {}
505 | }
506 |
507 | return Some(node);
508 | }
509 | }
510 |
511 | let mut temp_key_chars = curr_key.chars();
512 | let mut count = 0;
513 |
514 | // match characters
515 | for k in node.key.chars() {
516 | let t_k = match temp_key_chars.next() {
517 | Some(key) => key,
518 | None => break,
519 | };
520 |
521 | if t_k == k {
522 | count += t_k.len_utf8();
523 | } else {
524 | break;
525 | }
526 | }
527 |
528 | if count == node.key.len() && count != curr_key.len() {
529 | // break key
530 | break_key = true;
531 | key.insert(0, &curr_key[count..]);
532 | }
533 |
534 | if count != 0 && count == node.key.len() {
535 | if key.is_empty() {
536 | match &node.value {
537 | Some(curr_val) => {
538 | middlewares.append(&mut curr_val.middlewares.clone());
539 | return Some(node);
540 | }
541 | None => {
542 | for child in node.child_nodes.iter() {
543 | if child.key == "*" {
544 | if let Some(child_val) = &child.value {
545 | middlewares.append(&mut child_val.middlewares.clone());
546 | return Some(child);
547 | }
548 | }
549 | }
550 |
551 | continue;
552 | }
553 | }
554 | } else if let Some(final_val) =
555 | node.get_next_node(key, params, middlewares, break_key)
556 | {
557 | if let Some(curr_val) = &node.value {
558 | middlewares.append(&mut curr_val.middlewares.clone());
559 | }
560 |
561 | return Some(final_val);
562 | }
563 | }
564 |
565 | continue;
566 | }
567 |
568 | // Not found
569 | None
570 | }
571 | }
572 |
/// Action to be performed by the node
enum ActionName {
    /// Descend into the existing child whose key equals the inserted key.
    NextNode,
    /// Add a new child for the inserted key.
    CreateNewNode,
    /// Split the existing child at the prefix it shares with the inserted key.
    SplitNode,
    /// The child's key is a prefix of the inserted key; insert the rest below it.
    SplitKey,
    /// The inserted key conflicts with an existing child (param ambiguity).
    Error,
}

/// Details accompanying an [`ActionName`].
struct ActionPayload {
    /// Number of bytes (not chars) shared by the node key and the inserted key.
    match_count: usize,
    /// Index of the affected child in the parent's `child_nodes`.
    node_index: usize,
}

impl ActionPayload {
    pub fn new(match_count: usize, node_index: usize) -> Self {
        Self {
            match_count,
            node_index,
        }
    }
}

/// Container for actions will be performed in the trie
struct Action {
    name: ActionName,
    payload: ActionPayload,
}

impl Action {
    pub fn new(name: ActionName, payload: ActionPayload) -> Self {
        Self { name, payload }
    }
}
610 |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::Context;
    use crate::middleware::logger::Logger;
    use crate::router::ContextResult;

    async fn handler(ctx: Context) -> ContextResult {
        ctx.build("test").ok()
    }

    /// A GET route backed by the shared test handler.
    fn get_route() -> Route {
        Route::new(Method::GET, handler)
    }

    /// Asserts that `path` resolves to a GET route carrying exactly
    /// `expected_middlewares` middleware.
    fn assert_route(trie: &RouteTrie, path: &str, expected_middlewares: usize) {
        let route = trie
            .search_route(path)
            .unwrap_or_else(|| panic!("expected a route for {}", path));

        assert_eq!(route.get_middlewares().len(), expected_middlewares);
        assert!(route.get_route(&Method::GET).is_some());
    }

    #[test]
    fn radix_trie_head_test() {
        let mut route_trie = RouteTrie::new();
        route_trie.insert_default_middleware(Logger::new());
        route_trie.insert_route("/", get_route());

        assert_route(&route_trie, "/", 1);
    }

    #[test]
    fn radix_trie_normal_test() {
        let mut route_trie = RouteTrie::new();
        route_trie.insert_default_middleware(Logger::new());
        route_trie.insert_route("/normal/test/", get_route());
        route_trie.insert_route("/ノーマル/テスト/", get_route());
        route_trie.insert_middleware("/ノーマル/テスト/", Logger::new());

        assert_route(&route_trie, "/normal/test/", 1);
        assert_route(&route_trie, "/ノーマル/テスト/", 2);
    }

    #[test]
    fn radix_trie_not_found_test() {
        let mut route_trie = RouteTrie::new();
        route_trie.insert_default_middleware(Logger::new());
        route_trie.insert_route("/normal/test/", get_route());

        assert!(route_trie.search_route("/fail/test/").is_none());
    }

    #[test]
    fn radix_trie_split_node_and_key_test() {
        let mut route_trie = RouteTrie::new();
        route_trie.insert_default_middleware(Logger::new());

        let paths = [
            "/normal/test/",
            "/noral/test/",
            "/ノーマル/テスト/",
            "/ノーマル/テーブル/",
        ];
        for path in paths.iter() {
            route_trie.insert_route(path, get_route());
        }
        route_trie.insert_middleware("/noral/test/", Logger::new());
        route_trie.insert_middleware("/ノーマル/テーブル/", Logger::new());

        let expectations = [
            ("/normal/test/", 1),
            ("/noral/test/", 2),
            ("/ノーマル/テスト/", 1),
            ("/ノーマル/テーブル/", 2),
        ];
        for &(path, expected) in expectations.iter() {
            assert_route(&route_trie, path, expected);
        }
    }

    #[test]
    fn radix_trie_wildcard_test() {
        let mut route_trie = RouteTrie::new();
        route_trie.insert_route("/normal/test/*", get_route());
        for _ in 0..3 {
            route_trie.insert_middleware("/normal/test/*", Logger::new());
        }

        let paths = [
            "/normal/test/test",
            "/normal/test/123",
            "/normal/test/こんにちは",
            "/normal/test/啊",
        ];
        for path in paths.iter() {
            assert_route(&route_trie, path, 3);
        }
    }

    #[should_panic]
    #[test]
    fn radix_trie_wildcard_param_conflict_test() {
        let mut route_trie = RouteTrie::new();

        route_trie.insert_route("/normal/test/*", get_route());
        route_trie.insert_route("/normal/test/:param", get_route());
    }

    #[should_panic]
    #[test]
    fn radix_trie_param_wildcard_conflict_test() {
        let mut route_trie = RouteTrie::new();

        route_trie.insert_route("/normal/test/:param", get_route());
        route_trie.insert_route("/normal/test/*", get_route());
    }
}
796 |
--------------------------------------------------------------------------------