├── logo.png ├── .gitignore ├── src ├── core │ ├── mod.rs │ ├── middleware │ │ └── mod.rs │ ├── shortcuts.rs │ ├── logging.rs │ ├── server │ │ └── utils.rs │ ├── path │ │ └── mod.rs │ ├── cookie │ │ └── mod.rs │ ├── headers │ │ └── mod.rs │ ├── forms │ │ └── mod.rs │ ├── parser │ │ ├── urlencoded.rs │ │ ├── mod.rs │ │ └── multipart.rs │ ├── request │ │ └── mod.rs │ ├── session │ │ ├── mod.rs │ │ └── managers.rs │ ├── response │ │ ├── mod.rs │ │ └── status.rs │ ├── websocket │ │ ├── mod.rs │ │ └── frame.rs │ └── stream │ │ └── mod.rs ├── prelude.rs ├── forms │ ├── fields │ │ ├── mod.rs │ │ ├── uuid_field.rs │ │ ├── file_field.rs │ │ └── input_field.rs │ └── mod.rs └── lib.rs ├── .github └── workflows │ └── build.yml ├── Cargo.toml ├── LICENSE.md └── README.md /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/racoonframework/racoon/HEAD/logo.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /ignored 3 | /src/main.rs 4 | /templates 5 | .idea/ 6 | .cache/ 7 | .env 8 | -------------------------------------------------------------------------------- /src/core/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod request; 2 | pub mod cookie; 3 | pub mod session; 4 | pub mod path; 5 | pub mod server; 6 | pub mod response; 7 | pub mod parser; 8 | pub mod stream; 9 | pub mod logging; 10 | pub mod middleware; 11 | pub mod headers; 12 | pub mod forms; 13 | 14 | pub mod websocket; 15 | pub mod shortcuts; 16 | -------------------------------------------------------------------------------- /src/prelude.rs: -------------------------------------------------------------------------------- 1 | pub use crate::core::forms::FileFieldShortcut; 2 | pub use crate::core::request::Request; 3 | pub use
crate::core::response::Response; 4 | pub use crate::core::response::status::ResponseStatus; 5 | pub use crate::core::response::HttpResponse; 6 | pub use crate::core::response::JsonResponse; 7 | pub use crate::core::path::Path; 8 | pub use crate::core::shortcuts::SingleText; 9 | pub use crate::core::server::Server; 10 | pub use crate::view; 11 | pub use crate::wrap_view; 12 | -------------------------------------------------------------------------------- /src/core/middleware/mod.rs: -------------------------------------------------------------------------------- 1 | use std::future::Future; 2 | use std::pin::Pin; 3 | 4 | use crate::core::path::View; 5 | use crate::core::request::Request; 6 | use crate::core::response::AbstractResponse; 7 | 8 | /// Middleware signature: receives the request plus an optional view to dispatch to, and returns a boxed response future. 9 | pub type Middleware = fn(Request, Option<View>) -> Pin<Box<dyn Future<Output = Box<dyn AbstractResponse>> + Send>>; 10 | 11 | /// Adapts an `async fn(Request, Option<View>)` middleware function into a `Middleware` function pointer. 12 | /// 13 | /// `Request` and `View` are named through `$crate`, so call sites do not need to import them. 14 | #[macro_export] 15 | macro_rules! wrap_view { 16 | ($middleware_fn: ident) => { 17 | |request: $crate::core::request::Request, view: Option<$crate::core::path::View>| Box::pin($middleware_fn(request, view)) 18 | } 19 | } -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | paths: 7 | - 'src/**' 8 | - '.github/workflows/**' 9 | - 'Cargo.toml' 10 | 11 | pull_request: 12 | branches: [ "main" ] 13 | paths: 14 | - 'src/**' 15 | - '.github/workflows/**' 16 | - 'Cargo.toml' 17 | 18 | env: 19 | CARGO_TERM_COLOR: always 20 | 21 | jobs: 22 | build: 23 | 24 | runs-on: ubuntu-latest 25 | 26 | steps: 27 | - uses: actions/checkout@v4 28 | - name: Build 29 | run: cargo build --verbose 30 | - name: Run tests 31 | run: cargo test --verbose 32 | -------------------------------------------------------------------------------- /src/forms/fields/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod file_field; 2 | pub mod input_field; 3 | pub mod uuid_field; 4 | 5 | use
std::future::Future; 6 | 7 | use crate::core::forms::{Files, FormData}; 8 | 9 | type FieldResult = Box + Send + Sync + Unpin>; 10 | 11 | pub trait AbstractFields: Sync + Send { 12 | fn field_name(&self) -> FieldResult; 13 | fn validate( 14 | &mut self, 15 | form_data: &mut FormData, 16 | files: &mut Files, 17 | ) -> FieldResult>>; 18 | fn wrap(&self) -> Box; 19 | } 20 | 21 | pub type FormFields = Vec>; 22 | 23 | pub enum FieldError { 24 | Message(Vec), 25 | } 26 | -------------------------------------------------------------------------------- /src/core/shortcuts.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | pub trait SingleText { 4 | /// Performs case-insensitive lookup and returns first value found. 5 | fn value>(&self, name: S) -> Option<&String>; 6 | } 7 | 8 | 9 | impl SingleText for HashMap> { 10 | fn value>(&self, name: S) -> Option<&String> { 11 | let name = name.as_ref().to_lowercase(); // Lowercased once rather than on every iteration. 12 | 13 | for (key, values) in self.iter() { 14 | if key.to_lowercase() != name { 15 | continue; 16 | } 17 | 18 | if let Some(field) = values.get(0) { 19 | return Some(field); 20 | } 21 | } 22 | None 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! Racoon is a fast, fully customizable web framework for Rust focusing on simplicity. 3 | //! 4 | //! To use Racoon, you need Rust 1.75.0 or newer and the Tokio runtime. 5 | //! 6 | //! Getting started: 7 | //! ```rust,no_run 8 | //! use racoon::core::path::Path; 9 | //! use racoon::core::request::Request; 10 | //! use racoon::core::response::{HttpResponse, Response}; 11 | //! use racoon::core::response::status::ResponseStatus; 12 | //! use racoon::core::server::Server; 13 | //! 14 | //! use racoon::view; 15 | //! 16 | //! async fn home(request: Request) -> Response { 17 | //!
HttpResponse::ok().body("Home") 18 | //! } 19 | //! 20 | //! #[tokio::main] 21 | //! async fn main() { 22 | //! let paths = vec![ 23 | //! Path::new("/", view!(home)) 24 | //! ]; 25 | //! 26 | //! let result = Server::bind("127.0.0.1:8080") 27 | //! .urls(paths) 28 | //! .run().await; 29 | //! 30 | //! println!("Failed to run server: {:?}", result); 31 | //! } 32 | //! ``` 33 | //! 34 | 35 | pub mod core; 36 | pub mod forms; 37 | pub mod prelude; 38 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "racoon" 3 | version = "0.1.9" 4 | edition = "2021" 5 | authors = ["Tej Magar"] 6 | description = "Racoon is a fast, fully customizable web framework for Rust focusing on simplicity." 7 | license = "MIT" 8 | readme = "README.md" 9 | documentation = "https://racoonframework.github.io" 10 | repository = "https://github.com/racoonframework/racoon/" 11 | keywords = ["web", "framework", "racoon", "http"] 12 | categories = ["web-programming", "web-programming::http-server", "web-programming::websocket"] 13 | 14 | [dependencies] 15 | matchit = "0.8.0" 16 | tokio = { version = "1.38.0", features = ["full"] } 17 | log = "0.4.21" 18 | regex = "1.10.4" 19 | httparse = "1.8.0" 20 | urlencoding = "2.1.3" 21 | sha1 = "0.10.6" 22 | base64 = "0.22.0" 23 | serde = "1.0.199" 24 | serde_json = "1.0.116" 25 | uuid = { version = "1.8.0", features = ["v4"] } 26 | tokio-rustls = "0.26.0" 27 | rustls = "0.23.8" 28 | rustls-pemfile = "2.1.2" 29 | chrono = "0.4.38" 30 | sqlx = {version = "0.8.3", features=["runtime-tokio", "sqlite"]} 31 | rand = "0.9.0" 32 | async-tempfile = "0.6.0" 33 | 34 | [dev-dependencies] 35 | 36 | 37 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024
racoonframework 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/core/logging.rs: -------------------------------------------------------------------------------- 1 | //! Logging macros that only emit records when the `RACOON_LOGGING` environment variable is `true`. 2 | 3 | pub mod condition { 4 | use std::env; 5 | 6 | /// Returns `true` if `RACOON_LOGGING` equals `true` (case-insensitive); `false` if it is unset, not valid Unicode, or any other value. 7 | pub fn is_logging_enabled() -> bool { 8 | match env::var("RACOON_LOGGING") { 9 | Ok(value) => value.to_lowercase() == "true", 10 | Err(_) => false, 11 | } 12 | } 13 | } 14 | 15 | // Every macro reaches `is_logging_enabled` through `$crate`: a plain `crate::` path resolves at the 16 | // invocation site, i.e. in the caller's crate. Callers still need the `log` crate available. 17 | 18 | #[macro_export] 19 | macro_rules! racoon_debug { 20 | ($($arg:tt)*) => { 21 | if $crate::core::logging::condition::is_logging_enabled() { 22 | log::debug!($($arg)*); 23 | } 24 | } 25 | } 26 | 27 | #[macro_export] 28 | macro_rules! racoon_info { 29 | ($($arg:tt)*) => { 30 | if $crate::core::logging::condition::is_logging_enabled() { 31 | log::info!($($arg)*); 32 | } 33 | } 34 | } 35 | 36 | #[macro_export] 37 | macro_rules!
racoon_warn { 38 | ($($arg:tt)*) => { 39 | if $crate::core::logging::condition::is_logging_enabled() { 40 | log::warn!($($arg)*); 41 | } 42 | } 43 | } 44 | 45 | #[macro_export] 46 | macro_rules! racoon_trace { 47 | ($($arg:tt)*) => { 48 | if $crate::core::logging::condition::is_logging_enabled() { 49 | log::trace!($($arg)*); 50 | } 51 | } 52 | } 53 | 54 | #[macro_export] 55 | macro_rules! racoon_error { 56 | ($($arg:tt)*) => { 57 | if $crate::core::logging::condition::is_logging_enabled() { 58 | log::error!($($arg)*); 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/core/server/utils.rs: -------------------------------------------------------------------------------- 1 | use std::{ffi::OsStr, io::BufReader}; 2 | use std::sync::Arc; 3 | 4 | use rustls_pemfile::{certs, pkcs8_private_keys}; 5 | use tokio_rustls::TlsAcceptor; 6 | 7 | use crate::racoon_info; 8 | 9 | pub fn tls_acceptor_from_path>( 10 | certificate_path: S, 11 | private_key_path: S, 12 | ) -> std::io::Result { 13 | // Tries to read certificate file 14 | let certificate_file = match std::fs::File::open(certificate_path.as_ref()) { 15 | Ok(file) => file, 16 | Err(error) => { 17 | return Err(std::io::Error::other(format!( 18 | "Failed to open certificate file. Error: {}", 19 | error 20 | ))); 21 | } 22 | }; 23 | 24 | let mut certificate_buffered_reader = BufReader::new(certificate_file); 25 | 26 | // Extracts certificates 27 | let mut certificates = vec![]; 28 | for certificate in certs(&mut certificate_buffered_reader) { 29 | certificates.push(certificate?); 30 | } 31 | 32 | racoon_info!("Found certificates: {}", certificates.len()); 33 | 34 | // Tries to read private key file 35 | let private_key_file = match std::fs::File::open(private_key_path.as_ref()) { 36 | Ok(file) => file, 37 | Err(error) => { 38 | return Err(std::io::Error::other(format!( 39 | "Failed to open private key file.
Error: {}", 40 | error 41 | ))); 42 | } 43 | }; 44 | 45 | let mut private_key_buffered_reader = BufReader::new(private_key_file); 46 | 47 | // Extracts only the first PKCS#8 private key; other key encodings are not recognised here. 48 | let key_options = pkcs8_private_keys(&mut private_key_buffered_reader).next(); 49 | if let Some(key) = key_options { 50 | let private_key_der = rustls::pki_types::PrivateKeyDer::Pkcs8(key?); 51 | let server_config_result = rustls::ServerConfig::builder() 52 | .with_no_client_auth() 53 | .with_single_cert(certificates, private_key_der); 54 | 55 | let server_config = match server_config_result { 56 | Ok(config) => config, 57 | Err(error) => { 58 | return Err(std::io::Error::other(format!( 59 | "Failed to create server configuration. Error: {}", 60 | error 61 | ))); 62 | } 63 | }; 64 | 65 | return Ok(TlsAcceptor::from(Arc::new(server_config))); 66 | } else { 67 | return Err(std::io::Error::other("Private key not found.")); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/core/path/mod.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::future::Future; 3 | use std::pin::Pin; 4 | 5 | use crate::core::request::Request; 6 | use crate::core::response::status::ResponseStatus; 7 | use crate::core::response::{AbstractResponse, HttpResponse, Response}; 8 | use crate::core::shortcuts::SingleText; 9 | 10 | use super::headers::HeaderValue; 11 | 12 | pub type View = fn(Request) -> Pin> + Send>>; 13 | 14 | pub struct Path { 15 | pub name: String, 16 | pub view: View, 17 | } 18 | 19 | impl Path { 20 | pub fn new>(name: S, view: View) -> Self { 21 | Self { 22 | name: name.as_ref().to_string(), 23 | view, 24 | } 25 | } 26 | 27 | pub async fn resolve(request: Request, view: Option) -> Response { 28 | let mut response; 29 | let response_headers_from_request_ref = request.response_headers.clone(); 30 | 31 | if let Some(view) = view { 32 | response = view(request).await; 33 | } else { 34 |
response = HttpResponse::not_found().body("404 Page not found"); 35 | } 36 | 37 | // Appends every header queued in `request.response_headers` to the response; values the view already set are kept. 38 | 39 | let response_headers_from_request = response_headers_from_request_ref.lock().await; 40 | let response_headers = response.get_headers(); 41 | 42 | for (name, values) in response_headers_from_request.iter() { 43 | for value in values { 44 | response_headers.set_multiple(name, value); 45 | } 46 | } 47 | response 48 | } 49 | } 50 | 51 | impl Clone for Path { 52 | fn clone(&self) -> Self { 53 | Self { 54 | name: self.name.clone(), 55 | view: self.view, // `View` is a fn pointer, hence `Copy`. 56 | } 57 | } 58 | } 59 | 60 | pub type Paths = Vec; 61 | 62 | #[derive(Debug)] 63 | pub struct PathParams { 64 | params: HashMap, 65 | } 66 | 67 | impl Clone for PathParams { 68 | fn clone(&self) -> Self { 69 | Self { 70 | params: self.params.clone(), 71 | } 72 | } 73 | } 74 | 75 | impl SingleText for PathParams { 76 | fn value>(&self, name: S) -> Option<&String> { 77 | let name = name.as_ref(); 78 | self.params.get(name) 79 | } 80 | } 81 | 82 | impl PathParams { 83 | pub fn new() -> Self { 84 | Self { 85 | params: HashMap::new(), 86 | } 87 | } 88 | 89 | pub fn insert(&mut self, key: &str, value: &str) { 90 | self.params.insert(key.to_owned(), value.to_owned()); 91 | } 92 | 93 | pub fn map(&mut self) -> &mut HashMap { 94 | &mut self.params 95 | } 96 | } 97 | 98 | #[macro_export] 99 | macro_rules!
view { 100 | ($view_name: ident) => { 101 | |request: $crate::core::request::Request| Box::pin($view_name(request)) 102 | }; 103 | } 104 | -------------------------------------------------------------------------------- /src/core/cookie/mod.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::time::{Duration, SystemTime}; 3 | 4 | use chrono::{DateTime, Utc}; 5 | 6 | use crate::core::headers::{HeaderValue, Headers}; 7 | use crate::core::shortcuts::SingleText; 8 | 9 | pub type Cookies = HashMap; 10 | 11 | impl SingleText for Cookies { 12 | fn value>(&self, name: S) -> Option<&String> { 13 | let name = name.as_ref().to_lowercase(); 14 | 15 | for cookie_name in self.keys() { 16 | if cookie_name.to_lowercase() != name { 17 | continue; 18 | } 19 | 20 | let value = self.get(cookie_name); 21 | return value; 22 | } 23 | 24 | None 25 | } 26 | } 27 | 28 | /// 29 | /// Collects cookies from every `Cookie` header in `headers`; a name seen again overwrites the earlier value. 30 | /// 31 | pub fn parse_cookies_from_header(headers: &Headers) -> Cookies { 32 | // Reads Cookie header value from multiple header lines. 33 | // Example: 34 | // Cookie: name=John; 35 | // Cookie: location=ktm; 36 | let cookie_headers = headers.multiple_values("cookie"); 37 | let mut cookies = Cookies::new(); 38 | 39 | // Looping through multiple "Cookie: ..." headers.
40 | for cookie_header_value in cookie_headers { 41 | parse_cookie_header_value(cookie_header_value, &mut cookies); 42 | } 43 | 44 | cookies 45 | } 46 | 47 | /// Parses one `Cookie` header value (`k1=v1; k2=v2`) into `cookies`, URL-decoding names and values when possible. 48 | /// # Example 49 | /// 50 | /// ``` 51 | /// use racoon::core::cookie::parse_cookie_header_value; 52 | /// use racoon::core::cookie::Cookies; 53 | /// use racoon::core::shortcuts::SingleText; 54 | /// 55 | /// // Requires only value from "Cookie: name=John; location=Ktm" 56 | /// let cookie_header_value = "name=John; location=Ktm".to_string(); 57 | /// let mut cookies = Cookies::new(); 58 | /// 59 | /// parse_cookie_header_value(cookie_header_value, &mut cookies); 60 | /// 61 | /// let name = cookies.value("name"); 62 | /// let location = cookies.value("location"); 63 | /// let unknown = cookies.value("unknown"); 64 | /// 65 | /// assert_eq!(name, Some(&"John".to_string())); 66 | /// assert_eq!(location, Some(&"Ktm".to_string())); 67 | /// assert_eq!(unknown, None); 68 | /// ``` 69 | /// 70 | pub fn parse_cookie_header_value(cookie_header_value: String, cookies: &mut Cookies) { 71 | // A single Cookie header value holds several key=value pairs separated by semicolons. 72 | let raw_key_values: Vec<&str> = cookie_header_value.split(";").collect(); 73 | 74 | for raw_value in raw_key_values { 75 | let key_value: Vec<&str> = (*raw_value).splitn(2, "=").collect(); 76 | 77 | if key_value.len() >= 2 { 78 | let raw_key = key_value[0].trim(); 79 | // If url decoding fails, raw values are used.
80 | let key = match urlencoding::decode(raw_key) { 81 | Ok(decoded) => decoded.to_string(), 82 | Err(_) => raw_key.to_string(), 83 | }; 84 | 85 | let raw_value = key_value[1].trim(); 86 | let value = match urlencoding::decode(raw_value) { 87 | Ok(decoded) => decoded.to_string(), 88 | Err(_) => raw_value.to_string(), 89 | }; 90 | 91 | cookies.insert(key, value); 92 | } 93 | } 94 | } 95 | 96 | pub fn set_cookie>(headers: &mut Headers, name: S, value: S, max_age: Duration) { 97 | let now = SystemTime::now(); 98 | let expire_time = now + max_age; 99 | let datetime = DateTime::::from(expire_time); 100 | let expires_date = datetime.format("%a, %d-%b-%Y %H:%M:%S GMT"); 101 | 102 | let encoded_name = urlencoding::encode(name.as_ref()); 103 | let encoded_value = urlencoding::encode(value.as_ref()); 104 | 105 | let header_value = format!( 106 | "{}={}; Expires={}; Path=/; HttpOnly", 107 | encoded_name, encoded_value, expires_date 108 | ); 109 | headers.set_multiple("Set-Cookie", header_value); 110 | } 111 | 112 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Racoon 2 | 3 |

4 | 5 |
6 | 7 | Racoon 8 | 9 |

10 | 11 | 12 | Racoon is a fast, fully customizable web framework for Rust focusing on simplicity. 13 | 14 | To use Racoon, you need Rust `1.75.0` or newer and the `Tokio` runtime. 15 | 16 | 17 | [Learn Racoon](https://racoonframework.github.io) 18 | 19 | ## Installation 20 | 21 | Racoon runs on the `tokio` runtime. Install it with `cargo add tokio --features full`; the examples below use `#[tokio::main]`, which requires the `macros` and `rt-multi-thread` features. 22 | 23 | ```toml 24 | [dependencies] 25 | racoon = "0.1.9" 26 | ``` 27 | 28 | ## Basic Usage 29 | 30 | ```rust 31 | use racoon::core::path::Path; 32 | use racoon::core::request::Request; 33 | use racoon::core::response::{HttpResponse, Response}; 34 | use racoon::core::response::status::ResponseStatus; 35 | use racoon::core::server::Server; 36 | 37 | use racoon::view; 38 | 39 | async fn home(request: Request) -> Response { 40 | HttpResponse::ok().body("Home") 41 | } 42 | 43 | #[tokio::main] 44 | async fn main() { 45 | let paths = vec![ 46 | Path::new("/", view!(home)) 47 | ]; 48 | 49 | let result = Server::bind("127.0.0.1:8080") 50 | .urls(paths) 51 | .run().await; 52 | 53 | println!("Failed to run server: {:?}", result); 54 | } 55 | ``` 56 | 57 | ### File Handling 58 | 59 | There are multiple ways to handle files in Racoon. The simplest is the `request.parse()` method.
60 | 61 | ```rust 62 | use racoon::core::request::Request; 63 | use racoon::core::response::{HttpResponse, Response}; 64 | use racoon::core::response::status::ResponseStatus; 65 | use racoon::core::forms::FileFieldShortcut; 66 | use racoon::core::shortcuts::SingleText; 67 | 68 | async fn upload_form(request: Request) -> Response { 69 | if request.method == "POST" { 70 | // Parses request body 71 | let (form_data, files) = request.parse().await; 72 | println!("Name: {:?}", form_data.value("name")); 73 | 74 | let file = files.value("file"); 75 | println!("File: {:?}", file); 76 | return HttpResponse::ok().body("Uploaded"); 77 | } 78 | 79 | HttpResponse::bad_request().body("Use POST method to upload file.") 80 | } 81 | ``` 82 | 83 | For more information, see the [form handling guide](https://racoonframework.github.io/reading-form-data/). 84 | 85 | ## WebSocket example 86 | 87 | ```rust 88 | use racoon::core::path::Path; 89 | use racoon::core::request::Request; 90 | use racoon::core::response::Response; 91 | use racoon::core::server::Server; 92 | use racoon::core::websocket::{Message, WebSocket}; 93 | 94 | use racoon::view; 95 | 96 | async fn ws(request: Request) -> Response { 97 | let (websocket, connected) = WebSocket::from(&request).await; 98 | if !connected { 99 | // WebSocket handshake did not succeed 100 | return websocket.bad_request().await; 101 | } 102 | 103 | println!("WebSocket client connected."); 104 | 105 | // Receive incoming messages 106 | while let Some(message) = websocket.message().await { 107 | match message { 108 | Message::Text(text) => { 109 | println!("Message: {}", text); 110 | 111 | // Sends received message back 112 | let _ = websocket.send_text(text.as_str()).await; 113 | } 114 | _ => {} 115 | } 116 | } 117 | websocket.exit() 118 | } 119 | 120 | #[tokio::main] 121 | async fn main() { 122 | let paths = vec![ 123 | Path::new("/ws/", view!(ws)) 124 | ]; 125 | 126 | let _ = Server::bind("127.0.0.1:8080") 127 | .urls(paths) 128 | .run().await; 129 | } 130 |
``` 131 | 132 | 133 | ## Benchmark 134 | 135 | ```shell 136 | wrk -c100 -d8s -t4 http://127.0.0.1:8080 137 | ``` 138 | 139 | Results on an AMD Ryzen 5 7520U with Radeon Graphics. 140 | 141 | ```text 142 | Running 8s test @ http://127.0.0.1:8080/ 143 | 4 threads and 100 connections 144 | Thread Stats Avg Stdev Max +/- Stdev 145 | Latency 374.62us 219.91us 3.91ms 76.47% 146 | Req/Sec 62.42k 4.25k 70.53k 82.50% 147 | 1987462 requests in 8.00s, 140.26MB read 148 | Requests/sec: 248389.96 149 | Transfer/sec: 17.53MB 150 | ``` 151 | 152 | This is a minimal synthetic benchmark; it does not reflect real-world application performance. 153 | 154 | -------------------------------------------------------------------------------- /src/core/headers/mod.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | pub type Headers = HashMap>>; 4 | 5 | pub trait HeaderValue { 6 | /// Performs case-insensitive lookup and returns first value. 7 | fn value>(&self, name: S) -> Option; 8 | 9 | /// Performs case-insensitive lookup and returns every value of every matching header. 10 | fn multiple_values>(&self, name: S) -> Vec; 11 | 12 | /// Replaces all values stored under exactly `name` (case-sensitive key match) with `value`. 13 | fn set>(&mut self, name: &str, value: B); 14 | 15 | /// Appends `value` under exactly `name` (case-sensitive key match), keeping existing values.
16 | fn set_multiple>(&mut self, name: &str, value: B); 17 | } 18 | 19 | impl HeaderValue for Headers { 20 | fn value>(&self, name: S) -> Option { 21 | let name = name.as_ref().to_lowercase(); 22 | 23 | for (key, values) in self.iter() { 24 | if key.to_lowercase() != name { 25 | continue; 26 | } 27 | 28 | if let Some(value_bytes) = values.get(0) { 29 | let value = String::from_utf8_lossy(value_bytes); 30 | return Some(value.to_string()); 31 | } 32 | } 33 | 34 | None 35 | } 36 | 37 | fn multiple_values>(&self, name: S) -> Vec { 38 | let name = name.as_ref().to_lowercase(); 39 | 40 | let mut multiple_headers = vec![]; 41 | 42 | for (key, values) in self.iter() { 43 | if key.to_lowercase() != name { 44 | continue; 45 | } 46 | // `set_multiple` stores repeated headers in one Vec, so every value must be collected, not just the first. 47 | for value_bytes in values { 48 | let value = String::from_utf8_lossy(value_bytes); 49 | multiple_headers.push(value.to_string()); 50 | } 51 | } 52 | 53 | multiple_headers 54 | } 55 | 56 | fn set>(&mut self, name: &str, value: B) { 57 | let value = value.as_ref(); 58 | 59 | if let Some(values) = self.get_mut(&name.to_string()) { 60 | if values.len() > 0 { 61 | values.clear(); 62 | } 63 | 64 | values.push(value.to_vec()); 65 | } else { 66 | self.insert(name.to_string(), vec![value.to_vec()]); 67 | }; 68 | } 69 | 70 | fn set_multiple>(&mut self, name: &str, value: B) { 71 | let value = value.as_ref(); 72 | 73 | if let Some(values) = self.get_mut(&name.to_string()) { 74 | values.push(value.to_vec()); 75 | } else { 76 | self.insert(name.to_string(), vec![value.to_vec()]); 77 | }; 78 | } 79 | } 80 | 81 | /// Extracts the `boundary` parameter from a multipart `Content-Type` header value. 82 | /// # Example 83 | /// 84 | /// ``` 85 | /// use racoon::core::headers::multipart_boundary; 86 | /// 87 | /// let boundary_string = "application/form_data; boundary=----123456"; 88 | /// 89 | /// assert_eq!(multipart_boundary(&boundary_string.to_string()).is_ok(), true); 90 | /// assert_eq!(multipart_boundary(&boundary_string.to_string()).unwrap(), "----123456"); 91 | /// 92 | /// ``` 93 | /// 94 | pub fn
multipart_boundary(content_type: &String) -> std::io::Result<String> { 95 | // Parameters may appear in any order (e.g. `charset` first) and the value may be quoted; never panic on client input. 96 | for param in content_type.split(';').skip(1) { 97 | let Some((key, value)) = param.split_once('=') else { continue }; 98 | let boundary = value.trim().trim_matches('"'); 99 | if key.trim().eq_ignore_ascii_case("boundary") && !boundary.is_empty() { 100 | return Ok(boundary.to_string()); 101 | } 102 | } 103 | Err(std::io::Error::other("Boundary missing.")) 104 | } 105 | 106 | #[cfg(test)] 107 | pub mod tests { 108 | use crate::core::headers::{multipart_boundary, HeaderValue, Headers}; 109 | 110 | #[test] 111 | pub fn test_header_value() { 112 | let mut headers = Headers::new(); 113 | headers.set("Content-Type", b"text/html"); 114 | 115 | // Case-insensitive 116 | assert_eq!(headers.value("content-Type").is_some(), true); 117 | assert_eq!( 118 | headers.value("content-Type").unwrap(), 119 | "text/html".to_string() 120 | ); 121 | 122 | // Case sensitive 123 | assert_eq!(headers.get("content-Type").is_some(), false); 124 | } 125 | 126 | #[test] 127 | pub fn test_multipart_boundary() { 128 | let boundary_string = "application/form_data; boundary=----123456"; 129 | 130 | assert_eq!( 131 | multipart_boundary(&boundary_string.to_string()).is_ok(), 132 | true 133 | ); 134 | assert_eq!( 135 | multipart_boundary(&boundary_string.to_string()).unwrap(), 136 | "----123456" 137 | ); 138 | 139 | // Parameter order must not matter, quoted values are unquoted, and a missing boundary is an error. 140 | let reordered = "multipart/form-data; charset=utf-8; boundary=\"abc\"".to_string(); 141 | assert_eq!(multipart_boundary(&reordered).unwrap(), "abc"); 142 | assert!(multipart_boundary(&"multipart/form-data; charset=utf-8".to_string()).is_err()); 143 | } 144 | 145 | #[test] 146 | pub fn test_multiple_values_returns_every_value() { 147 | let mut headers = Headers::new(); 148 | headers.set_multiple("Cookie", b"a=1"); 149 | headers.set_multiple("Cookie", b"b=2"); 150 | assert_eq!(headers.multiple_values("cookie"), vec!["a=1".to_string(), "b=2".to_string()]); 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /src/core/forms/mod.rs: -------------------------------------------------------------------------------- 1 | use std::{collections::HashMap, path::PathBuf}; 2 | 3 | use async_tempfile::TempFile; 4 | 5 | #[derive(Debug)] 6 | pub struct FileField { 7 | pub name: String, 8 | temp_file: TempFile, 9 | pub temp_path: PathBuf, 10 | } 11 | 12 | impl FileField { 13 | pub fn from>(name: S, temp_file: TempFile) -> Self { 14 | let temp_path = temp_file.file_path().clone(); 15 | 16 | Self { 17 | name: name.as_ref().to_string(), 18 | temp_file, 19 | temp_path, 20 | }
21 | } 22 | 23 | pub fn temp_file(&self) -> &TempFile { 24 | &self.temp_file 25 | } 26 | } 27 | 28 | pub type Files = HashMap>; 29 | pub type FormData = HashMap>; 30 | 31 | pub trait FileFieldShortcut { 32 | /// Performs case-insensitive lookup and returns first file found. 33 | fn value>(&self, name: S) -> Option<&FileField>; 34 | } 35 | 36 | impl FileFieldShortcut for Files { 37 | fn value>(&self, name: S) -> Option<&FileField> { 38 | let name = name.as_ref().to_lowercase(); 39 | 40 | for (key, values) in self.iter() { 41 | if key.to_lowercase() != name { 42 | continue; 43 | } 44 | 45 | if let Some(field) = values.get(0) { 46 | return Some(field); 47 | } 48 | } 49 | None 50 | } 51 | } 52 | 53 | /// 54 | /// Size limits enforced while parsing a request body, guarding against oversized input. 55 | /// It can be set globally while creating the `Server` instance. 56 | /// 57 | /// # Example 58 | /// 59 | /// ```markdown 60 | /// 61 | /// Server::bind("127.0.0.1:8080") 62 | /// .urls(paths) 63 | /// .form_constraints(FormConstraints::new(...)) 64 | /// .run().await; 65 | /// ``` 66 | /// 67 | pub struct FormConstraints { 68 | /// Maximum allowed body size. 69 | max_body_size: usize, 70 | /// Maximum allowed form part header size. 71 | max_header_size: usize, 72 | /// Maximum allowed form part file size. 73 | max_file_size: usize, 74 | /// Maximum allowed form field value size. 75 | max_value_size: usize, 76 | /// Per-field size overrides (field name -> maximum size), consulted for both values and files.
77 | custom_max_sizes: HashMap, 78 | } 79 | 80 | impl FormConstraints { 81 | pub fn new( 82 | max_body_size: usize, 83 | max_header_size: usize, 84 | max_file_size: usize, 85 | max_value_size: usize, 86 | custom_max_sizes: HashMap, 87 | ) -> Self { 88 | Self { 89 | max_body_size, 90 | max_header_size, 91 | max_file_size, 92 | max_value_size, 93 | custom_max_sizes, 94 | } 95 | } 96 | 97 | pub fn max_body_size(&self, buffer_size: usize) -> usize { 98 | if buffer_size > self.max_body_size { 99 | return buffer_size; 100 | } 101 | 102 | // Default size 103 | self.max_body_size 104 | } 105 | 106 | pub fn max_header_size(&self, buffer_size: usize) -> usize { 107 | if buffer_size > self.max_header_size { 108 | return buffer_size; 109 | } 110 | 111 | // Default size 112 | self.max_header_size 113 | } 114 | 115 | pub fn max_value_size(&self, buffer_size: usize) -> usize { 116 | if buffer_size > self.max_value_size { 117 | return buffer_size; 118 | } 119 | 120 | // Default size 121 | self.max_value_size 122 | } 123 | pub fn max_size_for_field(&self, field_name: &String, buffer_size: usize) -> usize { 124 | if let Some(max_size) = self.custom_max_sizes.get(field_name) { 125 | if buffer_size > *max_size { 126 | return buffer_size; 127 | } 128 | return max_size.to_owned(); 129 | } 130 | 131 | // Default size; like every other limit here it is never smaller than `buffer_size`. 132 | self.max_value_size(buffer_size) 133 | } 134 | 135 | pub fn max_size_for_file(&self, field_name: &String, buffer_size: usize) -> usize { 136 | if let Some(max_size) = self.custom_max_sizes.get(field_name) { 137 | if buffer_size > *max_size { 138 | return buffer_size; 139 | } 140 | return max_size.to_owned(); 141 | } 142 | 143 | // Default size; never smaller than `buffer_size`, matching the other limits. 144 | self.max_file_size.max(buffer_size) 145 | } 146 | } 147 | 148 | #[derive(Debug)] 149 | pub enum FormFieldError { 150 | /// Max form part body size exceeded. 151 | MaxBodySizeExceed, 152 | /// Maximum form part header size exceeded. 153 | MaxHeaderSizeExceed, 154 | /// Maximum file size exceeded.
155 | MaxFileSizeExceed(String), 156 | /// Maximum text value length exceeded. 157 | MaxValueSizeExceed(String), 158 | /// (field_name, error, is_critical) 159 | /// If error is critical, don't expose to client. 160 | Others(Option, String, bool), 161 | } 162 | -------------------------------------------------------------------------------- /src/core/parser/urlencoded.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::sync::Arc; 3 | 4 | use crate::core::forms::{FormConstraints, FormData, FormFieldError}; 5 | use crate::core::headers::{HeaderValue, Headers}; 6 | use crate::core::parser::params::parse_url_encoded; 7 | 8 | use crate::core::stream::Stream; 9 | 10 | pub type FormFields = HashMap>; 11 | 12 | pub struct UrlEncodedParser { 13 | stream: Arc, 14 | form_constraints: Arc, 15 | content_length: usize, 16 | } 17 | 18 | impl UrlEncodedParser { 19 | pub fn from( 20 | stream: Arc, 21 | headers: &Headers, 22 | form_constraints: Arc, 23 | ) -> Result { 24 | let content_length; 25 | if let Some(value) = headers.value("Content-Length") { 26 | content_length = match value.parse::() { 27 | Ok(value) => value, 28 | Err(_) => { 29 | return Err(FormFieldError::Others( 30 | None, 31 | "Invalid content length header.".to_owned(), 32 | false 33 | )); 34 | } 35 | } 36 | } else { 37 | return Err(FormFieldError::Others( 38 | None, 39 | "Content-Length header is missing.".to_owned(), 40 | false 41 | )); 42 | } 43 | 44 | Ok(UrlEncodedParser { 45 | stream, 46 | form_constraints, 47 | content_length, 48 | }) 49 | } 50 | 51 | /// 52 | /// Reads body from the stream equal to the `Content-Length` specified in the header, decodes 53 | /// url encoded raw body and returns the result.
54 | /// 55 | async fn read_query_params_from_stream(&self) -> Result { 56 | let max_body_size = self 57 | .form_constraints 58 | .max_body_size(self.stream.buffer_size().await); 59 | 60 | if self.content_length > max_body_size { 61 | return Err(FormFieldError::MaxBodySizeExceed); 62 | } 63 | 64 | let mut buffer = vec![]; 65 | 66 | loop { 67 | if buffer.len() >= self.content_length { 68 | let value = String::from_utf8_lossy(&buffer); 69 | return Ok(parse_url_encoded(value.to_string().as_str())); 70 | } 71 | 72 | let chunk = match self.stream.read_chunk().await { 73 | Ok(bytes) => bytes, 74 | Err(error) => { 75 | return Err(FormFieldError::Others(None, error.to_string(), true)); 76 | } 77 | }; 78 | buffer.extend(chunk); 79 | } 80 | } 81 | 82 | /// 83 | /// Returns parsing result for url encoded request body considering form constraints. 84 | /// 85 | pub async fn parse( 86 | stream: Arc, 87 | headers: &Headers, 88 | form_constraints: Arc, 89 | ) -> Result { 90 | let parser = UrlEncodedParser::from(stream, headers, form_constraints)?; 91 | let params = parser.read_query_params_from_stream().await?; 92 | Ok(params) 93 | } 94 | } 95 | 96 | #[cfg(test)] 97 | pub mod test { 98 | use std::collections::HashMap; 99 | use std::sync::Arc; 100 | 101 | use crate::core::forms::{FormConstraints, FormFieldError}; 102 | use crate::core::headers::{HeaderValue, Headers}; 103 | use crate::core::shortcuts::SingleText; 104 | use crate::core::stream::{AbstractStream, TestStreamWrapper}; 105 | 106 | use super::UrlEncodedParser; 107 | 108 | #[tokio::test()] 109 | async fn test_url_encode_parser() { 110 | let mut headers = Headers::new(); 111 | let test_data = b"name=John&location=ktm".to_vec(); 112 | headers.set("Content-Length", test_data.len().to_string()); 113 | 114 | let stream: Box = Box::new(TestStreamWrapper::new(test_data, 1024)); 115 | 116 | let form_constraints = Arc::new(FormConstraints::new( 117 | 2 * 1024 * 1024, 118 | 2 * 1024 * 1024, 119 | 500 * 1024 * 1024, 120 | 2 * 1024 * 
1024, 121 | HashMap::new(), 122 | )); 123 | 124 | let url_encode_parser = 125 | UrlEncodedParser::parse(Arc::new(stream), &headers, form_constraints).await; 126 | assert_eq!(true, url_encode_parser.is_ok()); 127 | 128 | let parse_result = url_encode_parser.unwrap(); 129 | assert_eq!(Some(&"John".to_string()), parse_result.value("name")); 130 | assert_eq!(Some(&"ktm".to_string()), parse_result.value("location")); 131 | } 132 | 133 | #[tokio::test()] 134 | async fn test_no_content_length_parsing() { 135 | let headers = Headers::new(); 136 | let test_data = b"name=John&location=ktm".to_vec(); 137 | 138 | let stream: Box = Box::new(TestStreamWrapper::new(test_data, 1024)); 139 | 140 | let form_constraints = Arc::new(FormConstraints::new( 141 | 2 * 1024 * 1024, 142 | 2 * 1024 * 1024, 143 | 500 * 1024 * 1024, 144 | 2 * 1024 * 1024, 145 | HashMap::new(), 146 | )); 147 | 148 | let url_encode_parser = 149 | UrlEncodedParser::parse(Arc::new(stream), &headers, form_constraints).await; 150 | assert_eq!(true, url_encode_parser.is_err()); 151 | 152 | let form_field_error = url_encode_parser.unwrap_err(); 153 | match form_field_error { 154 | FormFieldError::Others(_, _, _) => { 155 | } 156 | _ => { 157 | assert!(true) 158 | } 159 | } 160 | } 161 | } 162 | -------------------------------------------------------------------------------- /src/forms/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod fields; 2 | 3 | use std::collections::HashMap; 4 | use std::future::Future; 5 | use std::vec; 6 | 7 | use serde::{Deserialize, Serialize}; 8 | 9 | use crate::core::forms::FormFieldError; 10 | use crate::core::request::Request; 11 | 12 | use crate::forms::fields::AbstractFields; 13 | use crate::racoon_error; 14 | 15 | pub type FormFields = Vec>; 16 | 17 | #[derive(Debug, Serialize, Deserialize)] 18 | pub struct ValidationError { 19 | pub field_errors: HashMap>, 20 | pub others: Vec, 21 | #[serde(skip_serializing)] 22 | pub critical_errors: 
Vec, 23 | } 24 | 25 | pub trait FormValidator: Sized + Send { 26 | fn new() -> Self; 27 | fn form_fields(&mut self) -> FormFields; 28 | fn validate<'a>( 29 | mut self, 30 | request: &'a Request, 31 | ) -> Box> + Sync + Send + Unpin + 'a> 32 | where 33 | Self: 'a, 34 | Self: Sync, 35 | { 36 | let request = request.clone(); 37 | 38 | Box::new(Box::pin(async move { 39 | let mut field_errors: HashMap> = HashMap::new(); 40 | let mut other_errors: Vec = vec![]; 41 | let mut critical_errors: Vec = vec![]; 42 | 43 | let (mut form_data, mut files) = 44 | match request.parse_body(request.form_constraints.clone()).await { 45 | Ok((form_data, files)) => (form_data, files), 46 | Err(error) => { 47 | match error { 48 | FormFieldError::MaxBodySizeExceed => { 49 | other_errors.push("Max body size exceed.".to_string()); 50 | } 51 | 52 | FormFieldError::MaxHeaderSizeExceed => { 53 | other_errors.push("Max header size exceed.".to_string()); 54 | } 55 | 56 | FormFieldError::MaxFileSizeExceed(field_name) => { 57 | let file_size_exceed_error = 58 | vec!["Max file size exceed.".to_string()]; 59 | if let Some(errors) = field_errors.get_mut(&field_name) { 60 | errors.extend_from_slice(&file_size_exceed_error); 61 | } else { 62 | field_errors.insert(field_name, file_size_exceed_error); 63 | } 64 | } 65 | 66 | FormFieldError::MaxValueSizeExceed(field_name) => { 67 | let value_length_exceed_error = 68 | vec!["Max value length exceed.".to_string()]; 69 | if let Some(errors) = field_errors.get_mut(&field_name) { 70 | errors.extend_from_slice(&value_length_exceed_error); 71 | } else { 72 | field_errors.insert(field_name, value_length_exceed_error); 73 | } 74 | } 75 | 76 | FormFieldError::Others(field_name, error, is_critical) => { 77 | if !is_critical { 78 | // Safe to expose error to client 79 | if let Some(field_name) = field_name { 80 | field_errors.insert(field_name, vec![error]); 81 | } else { 82 | other_errors.push(error); 83 | } 84 | } else { 85 | // May contains system errors. 
Not safe to expose to client, 86 | racoon_error!("Critical error: {}", error); 87 | critical_errors.push(format!("Field: {}", error)); 88 | } 89 | } 90 | } 91 | 92 | let validation_error = ValidationError { 93 | field_errors, 94 | others: other_errors, 95 | critical_errors, 96 | }; 97 | return Err(validation_error); 98 | } 99 | }; 100 | 101 | for mut field in self.form_fields() { 102 | let field_name = field.field_name().await; 103 | 104 | let result; 105 | if let Some(custom_validate_result) = 106 | self.custom_validate(&request, &field_name, &field).await 107 | { 108 | result = custom_validate_result; 109 | } else { 110 | result = field.validate(&mut form_data, &mut files).await; 111 | } 112 | 113 | match result { 114 | Ok(()) => {} 115 | Err(error) => { 116 | field_errors.insert(field_name, error); 117 | } 118 | } 119 | } 120 | 121 | if field_errors.len() > 0 { 122 | let validation_error = ValidationError { 123 | field_errors, 124 | others: vec![], 125 | critical_errors, 126 | }; 127 | return Err(validation_error); 128 | } 129 | 130 | Ok(self) 131 | })) 132 | } 133 | 134 | fn custom_validate( 135 | &mut self, 136 | _: &Request, 137 | _: &String, 138 | _: &Box, 139 | ) -> Box>>> + Sync + Send + Unpin + 'static> 140 | { 141 | Box::new(Box::pin(async move { None })) 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /src/core/request/mod.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::net::SocketAddr; 3 | use std::sync::atomic::{AtomicBool, Ordering}; 4 | use std::sync::Arc; 5 | 6 | use tokio::sync::Mutex; 7 | 8 | use crate::core::forms::{Files, FormConstraints, FormData}; 9 | 10 | use crate::core::headers::{HeaderValue, Headers}; 11 | use crate::core::parser::multipart::MultipartParser; 12 | use crate::core::parser::urlencoded::UrlEncodedParser; 13 | use crate::core::server::Context; 14 | use crate::core::stream::Stream; 15 | 16 
| use crate::core::path::PathParams; 17 | use crate::{racoon_debug, racoon_error}; 18 | 19 | use crate::core::cookie::{parse_cookies_from_header, Cookies}; 20 | use crate::core::session::{Session, SessionManager}; 21 | use crate::core::shortcuts::SingleText; 22 | 23 | use super::forms::FormFieldError; 24 | 25 | pub type QueryParams = HashMap>; 26 | 27 | pub struct Request { 28 | pub stream: Arc, 29 | context: Arc, 30 | pub scheme: String, 31 | pub method: String, 32 | pub path: String, 33 | pub http_version: u8, 34 | pub headers: Headers, 35 | pub path_params: PathParams, 36 | pub query_params: QueryParams, 37 | pub cookies: Cookies, 38 | pub session: Session, 39 | pub body_read: Arc, 40 | pub form_constraints: Arc, 41 | pub response_headers: Arc>, 42 | } 43 | 44 | impl Request { 45 | pub async fn from( 46 | stream: Arc, 47 | context: Arc, 48 | scheme: String, 49 | method: String, 50 | path: String, 51 | http_version: u8, 52 | headers: Headers, 53 | path_params: PathParams, 54 | query_params: QueryParams, 55 | session_manager: Arc, 56 | body_read: Arc, 57 | form_constraints: Arc, 58 | response_headers: Arc>, 59 | ) -> Self { 60 | let cookies = parse_cookies_from_header(&headers); 61 | let session_id = cookies.value("sessionid"); 62 | 63 | let session = Session::from(session_manager, session_id, response_headers.clone()); 64 | 65 | Self { 66 | stream, 67 | context, 68 | scheme, 69 | method, 70 | path, 71 | http_version, 72 | headers, 73 | path_params, 74 | query_params, 75 | cookies, 76 | session, 77 | body_read, 78 | form_constraints, 79 | response_headers, 80 | } 81 | } 82 | 83 | pub async fn remote_addr(&self) -> Option { 84 | self.stream.peer_addr().await 85 | } 86 | 87 | pub fn context(&self) -> Option<&T> { 88 | self.context.downcast_ref::() 89 | } 90 | 91 | pub async fn parse(&self) -> (FormData, Files) { 92 | return match self.parse_body(self.form_constraints.clone()).await { 93 | Ok((form_data, files)) => (form_data, files), 94 | Err(_) => (FormData::new(), 
Files::new()), 95 | }; 96 | } 97 | 98 | pub async fn parse_body( 99 | &self, 100 | form_constraints: Arc, 101 | ) -> Result<(FormData, Files), FormFieldError> { 102 | let form_data = FormData::new(); 103 | let files = Files::new(); 104 | 105 | let content_type; 106 | if let Some(value) = self.headers.value("Content-Type") { 107 | content_type = value; 108 | } else { 109 | racoon_debug!("Content type is missing."); 110 | return Ok((form_data, files)); 111 | } 112 | 113 | let body_read = self.body_read.clone(); 114 | body_read.store(false, Ordering::Relaxed); 115 | 116 | if content_type 117 | .to_lowercase() 118 | .starts_with("multipart/form-data") 119 | { 120 | racoon_debug!("Parsing with MultipartParser"); 121 | 122 | return match MultipartParser::parse( 123 | self.stream.clone(), 124 | form_constraints, 125 | &self.headers, 126 | ) 127 | .await 128 | { 129 | Ok((form_data, files)) => { 130 | self.body_read.store(true, Ordering::Relaxed); 131 | Ok((form_data, files)) 132 | } 133 | Err(error) => { 134 | racoon_error!("Error while parsing multipart body: {:?}", error); 135 | Err(error) 136 | } 137 | }; 138 | } else if content_type 139 | .to_lowercase() 140 | .starts_with("application/x-www-form-urlencoded") 141 | { 142 | racoon_debug!("Parsing with UrlEncoded parser."); 143 | 144 | return match UrlEncodedParser::parse( 145 | self.stream.clone(), 146 | &self.headers, 147 | form_constraints, 148 | ) 149 | .await 150 | { 151 | Ok(form_data) => { 152 | self.body_read.store(true, Ordering::Relaxed); 153 | Ok((form_data, files)) 154 | } 155 | Err(error) => { 156 | racoon_error!("Error while parsing x-www-urlencoded form. 
{:?}", error); 157 | Err(error) 158 | } 159 | }; 160 | } 161 | 162 | racoon_debug!("Unhandled enctype: {}", content_type); 163 | Ok((form_data, files)) 164 | } 165 | } 166 | 167 | impl Clone for Request { 168 | fn clone(&self) -> Self { 169 | Self { 170 | stream: self.stream.clone(), 171 | context: self.context.clone(), 172 | scheme: self.scheme.clone(), 173 | method: self.method.clone(), 174 | path: self.path.clone(), 175 | http_version: self.http_version.clone(), 176 | headers: self.headers.clone(), 177 | path_params: self.path_params.clone(), 178 | query_params: self.query_params.clone(), 179 | cookies: self.cookies.clone(), 180 | session: self.session.clone(), 181 | body_read: self.body_read.clone(), 182 | form_constraints: self.form_constraints.clone(), 183 | response_headers: self.response_headers.clone(), 184 | } 185 | } 186 | } 187 | 188 | #[derive(Debug)] 189 | pub enum RequestError { 190 | HeaderSizeExceed, 191 | Others(String), 192 | } 193 | -------------------------------------------------------------------------------- /src/core/session/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod managers; 2 | 3 | use std::future::Future; 4 | use std::sync::Arc; 5 | use std::time::Duration; 6 | 7 | use tokio::sync::Mutex; 8 | use uuid::Uuid; 9 | 10 | use crate::core::headers::Headers; 11 | 12 | use super::cookie; 13 | 14 | pub type SessionResult = Box + Send + Unpin>; 15 | 16 | pub trait AbstractSessionManager: Sync + Send { 17 | /// Set or update session value of the client. 18 | fn set( 19 | &self, 20 | session_id: &String, 21 | name: &str, 22 | value: &str, 23 | ) -> SessionResult>; 24 | 25 | /// Returns session value of the client. 26 | fn get(&self, session_id: &String, name: &str) -> SessionResult>; 27 | 28 | /// Removes session key and value of the client. 29 | fn remove(&self, session_id: &String, name: &str) -> SessionResult>; 30 | 31 | /// Removes all session key and value of the client. 
32 | fn destroy(&self, session_id: &String) -> SessionResult>; 33 | } 34 | 35 | pub type SessionManager = Box; 36 | 37 | pub struct Session { 38 | session_manager: Arc, 39 | session_id: Arc>>, 40 | response_headers: Arc>, 41 | } 42 | 43 | impl Clone for Session { 44 | fn clone(&self) -> Self { 45 | Self { 46 | session_manager: self.session_manager.clone(), 47 | session_id: self.session_id.clone(), 48 | response_headers: self.response_headers.clone(), 49 | } 50 | } 51 | } 52 | 53 | impl Session { 54 | pub fn from( 55 | session_manager: Arc, 56 | session_id: Option<&String>, 57 | response_headers: Arc>, 58 | ) -> Self { 59 | let session_id_value; 60 | 61 | if let Some(session_id) = session_id { 62 | session_id_value = Some(session_id.to_owned()); 63 | } else { 64 | session_id_value = None; 65 | } 66 | 67 | Self { 68 | session_manager, 69 | session_id: Arc::new(Mutex::new(session_id_value)), 70 | response_headers: response_headers.clone(), 71 | } 72 | } 73 | 74 | /// 75 | /// Session id of the client received from the cookie header `sessionid`. The request instance automatically initializes 76 | /// with new value if the `sessionid` header is not present. 77 | /// 78 | pub async fn session_id(&self) -> Option { 79 | let session_id_lock = self.session_id.lock().await; 80 | 81 | if let Some(session_id) = &*session_id_lock { 82 | return Some(session_id.to_owned()); 83 | } 84 | 85 | None 86 | } 87 | 88 | /// 89 | /// Set or update exisiting session value. 90 | /// 91 | /// # Examples 92 | /// ``` 93 | /// use racoon::core::request::Request; 94 | /// 95 | /// async fn home(request: Request) { 96 | /// let session = request.session; 97 | /// let _ = session.set("name", "John").await; 98 | /// } 99 | /// ``` 100 | /// 101 | pub async fn set>(&self, name: S, value: S) -> std::io::Result<()> { 102 | // If sessionid was not present in cookie, puts additional Set-Cookie header in the 103 | // response. 
104 | 105 | let mut session_id_lock = self.session_id.lock().await; 106 | let session_id; 107 | 108 | if !session_id_lock.is_some() { 109 | // Lazily creates sessionid when set method is called. 110 | session_id = Uuid::new_v4().to_string(); 111 | 112 | let mut response_headers = self.response_headers.lock().await; 113 | cookie::set_cookie( 114 | &mut response_headers, 115 | "sessionid", 116 | &session_id, 117 | Duration::from_secs(7 * 86400), 118 | ); 119 | 120 | *session_id_lock = Some(session_id); 121 | } 122 | 123 | if let Some(session_id) = &*session_id_lock { 124 | match self 125 | .session_manager 126 | .set(session_id, name.as_ref(), value.as_ref()) 127 | .await 128 | { 129 | Ok(()) => return Ok(()), 130 | Err(error) => { 131 | return Err(std::io::Error::other(error)); 132 | } 133 | }; 134 | } 135 | 136 | Ok(()) 137 | } 138 | 139 | /// 140 | /// Returns session value of type `Option`. 141 | /// 142 | /// # Examples 143 | /// ``` 144 | /// use racoon::core::request::Request; 145 | /// 146 | /// async fn home(request: Request) { 147 | /// let session = request.session; 148 | /// let name = session.get("name").await; 149 | /// } 150 | /// ``` 151 | /// 152 | /// This method does not return or print any error message by default. 153 | /// ``` 154 | /// use racoon::core::server::Server; 155 | /// 156 | /// // Enable debugging 157 | /// Server::enable_logging(); 158 | /// ``` 159 | /// 160 | pub async fn get>(&self, name: S) -> Option { 161 | let session_id_lock = self.session_id.lock().await; 162 | 163 | if let Some(session_id) = &*session_id_lock { 164 | return self.session_manager.get(session_id, name.as_ref()).await; 165 | } 166 | 167 | None 168 | } 169 | 170 | /// 171 | /// Removes session value. 
172 | /// 173 | /// # Examples 174 | /// ``` 175 | /// use racoon::core::request::Request; 176 | /// 177 | /// async fn home(request: Request) { 178 | /// let session = request.session; 179 | /// let _ = session.remove("name").await; 180 | /// } 181 | /// ``` 182 | /// 183 | pub async fn remove>(&self, name: S) -> std::io::Result<()> { 184 | let session_id_lock = self.session_id.lock().await; 185 | 186 | if let Some(session_id) = &*session_id_lock { 187 | return self.session_manager.remove(session_id, name.as_ref()).await; 188 | } 189 | 190 | Ok(()) 191 | } 192 | 193 | /// 194 | /// Removes all session values of the client. 195 | /// 196 | pub async fn destroy(&self) -> std::io::Result<()> { 197 | // Removes sesisonid from Cookie 198 | let response_headers_ref = self.response_headers.clone(); 199 | let mut response_headers = response_headers_ref.lock().await; 200 | 201 | let expire_header_value = format!( 202 | "{}=;Expires=Sun, 06 Nov 1994 08:49:37 GMT; Path=/", 203 | "sessionid" 204 | ); 205 | response_headers.insert( 206 | "Set-Cookie".to_string(), 207 | vec![expire_header_value.as_bytes().to_vec()], 208 | ); 209 | 210 | let session_lock = self.session_id.lock().await; 211 | if let Some(session_id) = &*session_lock { 212 | return self.session_manager.destroy(session_id).await; 213 | } 214 | 215 | Ok(()) 216 | } 217 | } 218 | -------------------------------------------------------------------------------- /src/core/response/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod status; 2 | 3 | use std::collections::HashMap; 4 | use std::time::Duration; 5 | 6 | use serde_json::json; 7 | 8 | use crate::core::cookie; 9 | use crate::core::headers::{HeaderValue, Headers}; 10 | use crate::core::response::status::ResponseStatus; 11 | 12 | pub trait AbstractResponse: Send { 13 | fn status(&self) -> (u32, String); 14 | fn serve_default(&mut self) -> bool; 15 | fn get_headers(&mut self) -> &mut Headers; 16 | fn get_body(&mut 
self) -> &mut Vec; 17 | fn should_close(&mut self) -> bool; 18 | } 19 | 20 | pub type Response = Box; 21 | 22 | pub struct HttpResponse { 23 | status_code: u32, 24 | status_text: String, 25 | headers: Headers, 26 | body: Vec, 27 | keep_alive: bool, 28 | serve_default: bool, 29 | } 30 | 31 | impl AbstractResponse for HttpResponse { 32 | fn status(&self) -> (u32, String) { 33 | (self.status_code, self.status_text.to_owned()) 34 | } 35 | 36 | fn serve_default(&mut self) -> bool { 37 | self.serve_default 38 | } 39 | 40 | fn get_headers(&mut self) -> &mut Headers { 41 | &mut self.headers 42 | } 43 | 44 | fn get_body(&mut self) -> &mut Vec { 45 | &mut self.body 46 | } 47 | 48 | fn should_close(&mut self) -> bool { 49 | !self.keep_alive 50 | } 51 | } 52 | 53 | impl HttpResponse { 54 | pub fn content_type(mut self, value: &str) -> Self { 55 | self.headers.set("Content-Type", value.as_bytes()); 56 | self 57 | } 58 | 59 | pub fn keep_alive(mut self, is_alive: bool) -> Self { 60 | self.keep_alive = !is_alive; 61 | self 62 | } 63 | 64 | pub fn disable_serve_default(mut self) -> Self { 65 | self.serve_default = false; 66 | self 67 | } 68 | 69 | pub fn location(mut self, url: &str) -> Box { 70 | self.get_headers().set("Location", url); 71 | Box::new(self) 72 | } 73 | 74 | pub fn body>(mut self, data: S) -> Box { 75 | let data = data.as_ref(); 76 | 77 | self.headers 78 | .set("Content-Length", data.len().to_string()); 79 | 80 | self.headers.set("Content-Type", "text/html"); 81 | 82 | if self.headers.value("Connection").is_none() { 83 | if self.keep_alive { 84 | self.headers.set("Connection", "keep-alive"); 85 | } else { 86 | self.headers.set("Connection", "close"); 87 | } 88 | } 89 | 90 | self.body = data.as_bytes().to_vec(); 91 | 92 | Box::new(self) 93 | } 94 | 95 | pub fn empty(self) -> Box { 96 | self.body("") 97 | } 98 | 99 | pub fn set_cookie>(&mut self, name: S, value: S, max_age: Duration) { 100 | let headers = self.get_headers(); 101 | cookie::set_cookie(headers, name, 
value, max_age); 102 | } 103 | 104 | pub fn remove_cookie>(&mut self, name: S) { 105 | let headers = &mut self.headers; 106 | let expire_header_value = format!( 107 | "{}=;Expires=Sun, 06 Nov 1994 08:49:37 GMT; Path=/", 108 | name.as_ref() 109 | ); 110 | headers.set_multiple("Set-Cookie", expire_header_value); 111 | } 112 | } 113 | 114 | impl ResponseStatus for HttpResponse { 115 | fn with_status(status_code: u32, status_text: &str) -> Self { 116 | Self { 117 | status_code, 118 | status_text: status_text.to_owned(), 119 | headers: HashMap::new(), 120 | body: vec![], 121 | keep_alive: true, 122 | serve_default: true, 123 | } 124 | } 125 | } 126 | 127 | pub fn response_to_bytes(response: &mut Box) -> Vec { 128 | let mut response_bytes: Vec = Vec::with_capacity(response.get_body().len()); 129 | let (status_code, status_text) = response.status(); 130 | 131 | // Append header response start line 132 | let response_header_begin = format!("HTTP/1.1 {} {}\r\n", status_code, status_text); 133 | response_bytes.extend(response_header_begin.as_bytes()); 134 | 135 | // Append headers 136 | response.get_headers().iter().for_each(|(name, values)| { 137 | for value in values { 138 | response_bytes.extend(name.as_bytes()); 139 | response_bytes.extend(b": "); 140 | response_bytes.extend(value); 141 | response_bytes.extend(b"\r\n"); 142 | } 143 | }); 144 | 145 | response_bytes.extend(b"\r\n"); 146 | 147 | // Body start 148 | response_bytes.extend(response.get_body().as_slice()); 149 | response_bytes 150 | } 151 | 152 | pub struct JsonResponse { 153 | http_response: HttpResponse, 154 | } 155 | 156 | impl JsonResponse { 157 | pub fn body(mut self, json: serde_json::Value) -> Box { 158 | let json_text = json.to_string(); 159 | 160 | self.http_response 161 | .headers 162 | .set("Content-Length", json_text.len().to_string().as_bytes()); 163 | 164 | if self.http_response.headers.value("Connection").is_none() { 165 | if self.http_response.keep_alive { 166 | self.http_response 167 | .headers 
168 | .set("Connection", "keep-alive"); 169 | } else { 170 | self.http_response 171 | .headers 172 | .set("Connection", "close"); 173 | } 174 | } 175 | 176 | self.http_response.body = json_text.as_bytes().to_vec(); 177 | Box::new(self) 178 | } 179 | 180 | /// 181 | /// Creates empty JSON object response. 182 | /// 183 | pub fn empty(self) -> Box { 184 | self.body(json!({})) 185 | } 186 | 187 | /// 188 | /// Sets cookie in max age from "/" path. 189 | /// 190 | pub fn set_cookie>(&mut self, name: S, value: S, max_age: Duration) { 191 | self.http_response.set_cookie(name, value, max_age); 192 | } 193 | 194 | /// 195 | /// Removes cookie from "/" path. 196 | /// 197 | pub fn remove_cookie>(&mut self, name: S) { 198 | self.http_response.remove_cookie(name) 199 | } 200 | } 201 | 202 | impl AbstractResponse for JsonResponse { 203 | fn status(&self) -> (u32, String) { 204 | self.http_response.status() 205 | } 206 | 207 | fn serve_default(&mut self) -> bool { 208 | self.http_response.serve_default 209 | } 210 | 211 | fn get_headers(&mut self) -> &mut Headers { 212 | self.http_response.get_headers() 213 | } 214 | 215 | fn get_body(&mut self) -> &mut Vec { 216 | self.http_response.get_body() 217 | } 218 | 219 | fn should_close(&mut self) -> bool { 220 | self.http_response.should_close() 221 | } 222 | } 223 | 224 | impl ResponseStatus for JsonResponse { 225 | fn with_status(status_code: u32, status_text: &str) -> Self { 226 | let mut http_response = HttpResponse::with_status(status_code, status_text); 227 | let headers = http_response.get_headers(); 228 | headers.set("Content-Type", "application/json"); 229 | 230 | Self { http_response } 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /src/core/parser/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod multipart; 2 | pub mod urlencoded; 3 | 4 | pub mod headers { 5 | use std::collections::HashMap; 6 | use std::sync::Arc; 7 | 
8 | use crate::core::headers::{Headers, HeaderValue}; 9 | use crate::core::request::RequestError; 10 | use crate::core::server::RequestConstraints; 11 | use crate::core::stream::Stream; 12 | 13 | #[derive(Debug)] 14 | pub struct RequestHeaderResult { 15 | pub method: Option, 16 | pub http_version: Option, 17 | pub raw_path: Option, 18 | pub headers: Headers, 19 | } 20 | 21 | impl RequestHeaderResult { 22 | pub fn new() -> Self { 23 | Self { 24 | method: None, 25 | http_version: None, 26 | raw_path: None, 27 | headers: HashMap::new(), 28 | } 29 | } 30 | } 31 | 32 | pub async fn read_request_headers(stream: Arc, 33 | request_constraints: Arc) 34 | -> Result { 35 | let max_request_header_size = request_constraints.max_request_header_size(stream.buffer_size().await); 36 | 37 | let mut buffer: Vec = vec![]; 38 | 39 | let mut bytes_read = 0; 40 | 41 | loop { 42 | let chunk = match stream.read_chunk().await { 43 | Ok(bytes) => bytes, 44 | Err(error) => { 45 | return Err(RequestError::Others(error.to_string())); 46 | } 47 | }; 48 | bytes_read += chunk.len(); 49 | buffer.extend(chunk); 50 | 51 | if bytes_read > max_request_header_size { 52 | return Err(RequestError::HeaderSizeExceed); 53 | } 54 | 55 | let mut headers = vec![httparse::EMPTY_HEADER; request_constraints.max_header_count]; 56 | let mut request = httparse::Request::new(&mut headers); 57 | let result = request.parse(&buffer); 58 | 59 | match result { 60 | Ok(status) => { 61 | if status.is_partial() { 62 | continue; 63 | } 64 | 65 | let matched_position = status.unwrap(); 66 | let partial_body = &buffer[matched_position..]; 67 | let _ = stream.restore_payload(partial_body).await; 68 | 69 | let request_method; 70 | if let Some(method) = request.method { 71 | request_method = Some(method.to_string()); 72 | } else { 73 | request_method = None; 74 | } 75 | 76 | let http_version; 77 | if let Some(version) = request.version { 78 | http_version = Some(version); 79 | } else { 80 | http_version = None; 81 | } 82 | 83 | let 
path; 84 | if let Some(request_path) = request.path { 85 | path = Some(request_path.to_owned()); 86 | } else { 87 | path = None; 88 | } 89 | 90 | let mut headers = HashMap::new(); 91 | request.headers.iter().for_each(|header| { 92 | headers.set_multiple(header.name, header.value); 93 | }); 94 | 95 | if status.is_complete() { 96 | return Ok(RequestHeaderResult { 97 | method: request_method, 98 | http_version, 99 | raw_path: path, 100 | headers, 101 | }); 102 | } 103 | } 104 | Err(_) => { 105 | // Not actual error 106 | // Wait until header is not completely found 107 | } 108 | } 109 | } 110 | } 111 | } 112 | 113 | 114 | pub mod path { 115 | /// 116 | /// Does not include `?` character in raw query. 117 | /// 118 | pub fn path_and_raw_query>(raw_path: S) -> (String, String) { 119 | let raw_path = raw_path.as_ref().to_string(); 120 | let split: Vec<&str> = raw_path.splitn(2, "?").collect(); 121 | 122 | let path; 123 | if let Some(value) = split.get(0) { 124 | path = value.to_string(); 125 | } else { 126 | path = raw_path.to_owned(); 127 | } 128 | 129 | let raw_query; 130 | if let Some(value) = split.get(1) { 131 | raw_query = value.to_string(); 132 | } else { 133 | raw_query = "".to_owned(); 134 | } 135 | 136 | return (path, raw_query); 137 | } 138 | } 139 | 140 | pub mod params { 141 | use std::collections::HashMap; 142 | use crate::core::parser::path::path_and_raw_query; 143 | 144 | /// 145 | /// # Examples 146 | /// ``` 147 | /// use racoon::core::shortcuts::SingleText; 148 | /// use racoon::core::parser::params::query_params_from_raw; 149 | /// 150 | /// let raw_path = "?name=John&location=ktm"; 151 | /// let query_params = query_params_from_raw(raw_path); 152 | /// 153 | /// let name = query_params.value("name"); 154 | /// let location = query_params.value("location"); 155 | /// let unknown = query_params.value("unknown"); 156 | /// assert_eq!(name, Some(&"John".to_string())); 157 | /// assert_eq!(location, Some(&"ktm".to_string())); 158 | /// assert_eq!(unknown, 
None); 159 | /// 160 | /// ``` 161 | /// 162 | pub fn query_params_from_raw>(raw_path: S) -> HashMap> { 163 | let (_, raw_query) = path_and_raw_query(raw_path.as_ref()); 164 | parse_url_encoded(&raw_query) 165 | } 166 | 167 | pub fn parse_url_encoded>(text: S) -> HashMap> { 168 | let text = text.as_ref(); 169 | let mut params = HashMap::new(); 170 | if text.len() == 0 { 171 | return params; 172 | } 173 | 174 | let values = text.split("&"); 175 | 176 | for value in values { 177 | let key_values: Vec<&str> = value.split("=").collect(); 178 | if key_values.len() >= 2 { 179 | let name = key_values.get(0).unwrap(); 180 | let value = key_values.get(1).unwrap(); 181 | 182 | let name_formatted = match urlencoding::decode(name) { 183 | Ok(value) => value.to_string(), 184 | Err(_) => name.to_string() 185 | }; 186 | 187 | let value_formatted = match urlencoding::decode(value) { 188 | Ok(value) => value.to_string(), 189 | Err(_) => value.to_string() 190 | }; 191 | 192 | if !params.contains_key(&name_formatted) { 193 | params.insert(name.to_string(), Vec::new()); 194 | } 195 | 196 | let values = params.get_mut(&name_formatted).unwrap(); 197 | values.push(value_formatted); 198 | } 199 | } 200 | return params; 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /src/core/response/status.rs: -------------------------------------------------------------------------------- 1 | /// 2 | /// More information: 3 | /// 4 | pub trait ResponseStatus: Sized { 5 | fn with_status(status_code: u32, status_text: &str) -> Self; 6 | 7 | fn r#continue() -> Self { 8 | Self::with_status(100, "Continue") 9 | } 10 | 11 | fn switching_protocols() -> Self { 12 | Self::with_status(101, "Switching Protocols") 13 | } 14 | 15 | fn processing() -> Self { 16 | Self::with_status(102, "Processing") 17 | } 18 | 19 | fn early_hints() -> Self { 20 | Self::with_status(103, "Early Hints") 21 | } 22 | 23 | fn ok() -> Self { 24 | Self::with_status(200, "OK") 25 | } 26 
| 27 | fn created() -> Self { 28 | Self::with_status(201, "Created") 29 | } 30 | 31 | fn accepted() -> Self { 32 | Self::with_status(202, "Accepted") 33 | } 34 | 35 | fn non_authoritative_information() -> Self { 36 | Self::with_status(203, "Non-Authoritative Information") 37 | } 38 | 39 | fn no_content() -> Self { 40 | Self::with_status(204, "No Content") 41 | } 42 | 43 | fn reset_content() -> Self { 44 | Self::with_status(205, "Reset Content") 45 | } 46 | 47 | fn partial_content() -> Self { 48 | Self::with_status(206, "Partial Content") 49 | } 50 | 51 | fn multi_status() -> Self { 52 | Self::with_status(207, "Multi-Status") 53 | } 54 | 55 | fn already_reported() -> Self { 56 | Self::with_status(208, "Already Reported") 57 | } 58 | 59 | fn im_used() -> Self { 60 | Self::with_status(226, "IM Used") 61 | } 62 | 63 | fn multiple_choices() -> Self { 64 | Self::with_status(300, "Multiple Choices") 65 | } 66 | 67 | fn moved_permanently() -> Self { 68 | Self::with_status(301, "Moved Permanently") 69 | } 70 | 71 | fn found() -> Self { 72 | Self::with_status(302, "Found") 73 | } 74 | 75 | fn see_other() -> Self { 76 | Self::with_status(303, "See Other") 77 | } 78 | 79 | fn not_modified() -> Self { 80 | Self::with_status(304, "Not Modified") 81 | } 82 | 83 | /// 84 | /// Deprecated 85 | /// 86 | fn use_proxy() -> Self { 87 | Self::with_status(305, "Use Proxy") 88 | } 89 | 90 | /// 91 | /// Deprecated 92 | /// 93 | fn unused() -> Self { 94 | Self::with_status(306, "Unused") 95 | } 96 | 97 | fn temporary_redirect() -> Self { 98 | Self::with_status(307, "Temporary Redirect") 99 | } 100 | 101 | fn permanent_redirect() -> Self { 102 | Self::with_status(308, "Permanent Redirect") 103 | } 104 | 105 | fn bad_request() -> Self { 106 | Self::with_status(400, "Bad Request") 107 | } 108 | 109 | fn unauthorized() -> Self { 110 | Self::with_status(401, "Unauthorized") 111 | } 112 | 113 | /// 114 | /// Experimental. Expect behaviour to change in the future.
115 | /// 116 | fn payment_required() -> Self { 117 | Self::with_status(402, "Payment Required") 118 | } 119 | 120 | fn forbidden() -> Self { 121 | Self::with_status(403, "Forbidden") 122 | } 123 | 124 | fn not_found() -> Self { 125 | Self::with_status(404, "Not Found") 126 | } 127 | 128 | fn method_not_allowed() -> Self { 129 | Self::with_status(405, "Method Not Allowed") 130 | } 131 | 132 | fn not_acceptable() -> Self { 133 | Self::with_status(406, "Not Acceptable") 134 | } 135 | 136 | fn proxy_authentication_required() -> Self { 137 | Self::with_status(407, "Proxy Authentication Required") 138 | } 139 | 140 | fn request_timeout() -> Self { 141 | Self::with_status(408, "Request Timeout") 142 | } 143 | 144 | fn conflict() -> Self { 145 | Self::with_status(409, "Conflict") 146 | } 147 | 148 | fn gone() -> Self { 149 | Self::with_status(410, "Gone") 150 | } 151 | 152 | fn length_required() -> Self { 153 | Self::with_status(411, "Length Required") 154 | } 155 | 156 | fn precondition_failed() -> Self { 157 | Self::with_status(412, "Precondition Failed") 158 | } 159 | 160 | fn payload_too_large() -> Self { 161 | Self::with_status(413, "Payload Too Large") 162 | } 163 | 164 | fn uri_too_long() -> Self { 165 | Self::with_status(414, "URI Too Long") 166 | } 167 | 168 | fn unsupported_media_type() -> Self { 169 | Self::with_status(415, "Unsupported Media Type") 170 | } 171 | 172 | fn range_not_satisfiable() -> Self { 173 | Self::with_status(416, "Range Not Satisfiable") 174 | } 175 | 176 | fn expectation_failed() -> Self { 177 | Self::with_status(417, "Expectation Failed") 178 | } 179 | 180 | fn im_a_teapot() -> Self { 181 | Self::with_status(418, "I'm a teapot") 182 | } 183 | 184 | fn misdirected_request() -> Self { 185 | Self::with_status(421, "Misdirected Request") 186 | } 187 | 188 | fn unprocessable_content() -> Self { 189 | Self::with_status(422, "Unprocessable Content") 190 | } 191 | 192 | fn locked() -> Self { 193 | Self::with_status(423, "Locked") 194 | } 195 |
196 | fn failed_dependency() -> Self { 197 | Self::with_status(424, "Failed Dependency") 198 | } 199 | 200 | 201 | /// 202 | /// Experimental. Expect behaviour to change in the future. 203 | /// 204 | fn too_early() -> Self { 205 | Self::with_status(425, "Too Early") 206 | } 207 | 208 | fn upgrade_required() -> Self { 209 | Self::with_status(426, "Upgrade Required") 210 | } 211 | 212 | fn precondition_required() -> Self { 213 | Self::with_status(428, "Precondition Required") 214 | } 215 | 216 | fn too_many_requests() -> Self { 217 | Self::with_status(429, "Too Many Requests") 218 | } 219 | 220 | fn request_header_fields_too_large() -> Self { 221 | Self::with_status(431, "Request Header Fields Too Large") 222 | } 223 | 224 | fn unavailable_for_legal_reasons() -> Self { 225 | Self::with_status(451, "Unavailable For Legal Reasons") 226 | } 227 | 228 | fn internal_server_error() -> Self { 229 | Self::with_status(500, "Internal Server Error") 230 | } 231 | 232 | fn not_implemented() -> Self { 233 | Self::with_status(501, "Not Implemented") 234 | } 235 | 236 | fn bad_gateway() -> Self { 237 | Self::with_status(502, "Bad Gateway") 238 | } 239 | 240 | fn service_unavailable() -> Self { 241 | Self::with_status(503, "Service Unavailable") 242 | } 243 | 244 | fn gateway_timeout() -> Self { 245 | Self::with_status(504, "Gateway Timeout") 246 | } 247 | 248 | fn http_version_not_supported() -> Self { 249 | Self::with_status(505, "HTTP Version Not Supported") 250 | } 251 | 252 | fn variant_also_negotiates() -> Self { 253 | Self::with_status(506, "Variant Also Negotiates") 254 | } 255 | 256 | fn insufficient_storage() -> Self { 257 | Self::with_status(507, "Insufficient Storage") 258 | } 259 | 260 | fn loop_detected() -> Self { 261 | Self::with_status(508, "Loop Detected") 262 | } 263 | 264 | fn not_extended() -> Self { 265 | Self::with_status(510, "Not Extended") 266 | } 267 | 268 | fn network_authentication_required() -> Self { 269 | Self::with_status(511, "Network
Authentication Required") 270 | } 271 | } -------------------------------------------------------------------------------- /src/forms/fields/uuid_field.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::marker::PhantomData; 3 | use std::sync::atomic::{AtomicBool, Ordering}; 4 | use std::sync::Arc; 5 | 6 | use tokio::sync::Mutex; 7 | 8 | use uuid::Uuid; 9 | 10 | use crate::core::forms::{Files, FormData}; 11 | use crate::forms::fields::{AbstractFields, FieldResult}; 12 | 13 | pub trait ToTypeT { 14 | fn from_vec(values: &mut Vec) -> Option 15 | where 16 | Self: Sized; 17 | 18 | fn is_optional() -> bool; 19 | } 20 | 21 | impl ToTypeT for Uuid { 22 | fn from_vec(values: &mut Vec) -> Option 23 | where 24 | Self: Sized, 25 | { 26 | if values.len() > 0 { 27 | let value = values.remove(0); 28 | match Uuid::parse_str(&value) { 29 | Ok(uuid) => return Some(uuid), 30 | _ => {} 31 | } 32 | } 33 | 34 | None 35 | } 36 | 37 | fn is_optional() -> bool { 38 | false 39 | } 40 | } 41 | 42 | impl ToTypeT for Option { 43 | fn from_vec(values: &mut Vec) -> Option 44 | where 45 | Self: Sized, 46 | { 47 | if values.len() > 0 { 48 | let value = values.remove(0); 49 | match Uuid::parse_str(&value) { 50 | Ok(uuid) => return Some(Some(uuid)), 51 | _ => {} 52 | } 53 | } 54 | 55 | // Outer Some denotes conversion success with value None (also reached when the submitted value is not a valid UUID). 56 | Some(None) 57 | } 58 | 59 | fn is_optional() -> bool { 60 | true 61 | } 62 | } 63 | 64 | impl ToTypeT for Vec { 65 | fn from_vec(values: &mut Vec) -> Option 66 | where 67 | Self: Sized, 68 | { 69 | let mut uuids = vec![]; 70 | if values.len() == 0 { 71 | return None; 72 | } 73 | 74 | for i in (0..values.len()).rev() { 75 | let value = values.remove(i); 76 | match Uuid::parse_str(&value) { 77 | Ok(value) => { 78 | uuids.insert(0, value); 79 | } 80 | _ => { 81 | // Return conversion failed. Invalid UUID found.
82 | return None; 83 | } 84 | } 85 | } 86 | 87 | Some(uuids) 88 | } 89 | 90 | fn is_optional() -> bool { 91 | false 92 | } 93 | } 94 | 95 | impl ToTypeT for Option> { 96 | fn from_vec(values: &mut Vec) -> Option 97 | where 98 | Self: Sized, 99 | { 100 | let mut uuids = vec![]; 101 | if values.len() == 0 { 102 | return Some(None); 103 | } 104 | 105 | for i in (0..values.len()).rev() { 106 | let value = values.remove(i); 107 | match Uuid::parse_str(&value) { 108 | Ok(value) => { 109 | uuids.insert(0, value); 110 | } 111 | _ => { 112 | // Invalid UUID found: the whole field is treated as absent (conversion succeeds with None). 113 | return Some(None); 114 | } 115 | } 116 | } 117 | 118 | Some(Some(uuids)) 119 | } 120 | 121 | fn is_optional() -> bool { 122 | true 123 | } 124 | } 125 | 126 | type BoxResult = Box; 127 | 128 | pub enum UuidFieldError<'a> { 129 | /// (field_name) 130 | MissingField(&'a String), 131 | /// (field_name, values) 132 | InvalidUuid(&'a String, &'a Vec), 133 | } 134 | 135 | pub type ErrorHandler = Box) -> Vec>; 136 | 137 | pub struct UuidField { 138 | field_name: String, 139 | result: Arc>>, 140 | validated: Arc, 141 | error_handler: Option>, 142 | phantom: PhantomData, 143 | } 144 | 145 | impl Clone for UuidField { 146 | fn clone(&self) -> Self { 147 | Self { 148 | field_name: self.field_name.clone(), 149 | result: self.result.clone(), 150 | validated: self.validated.clone(), 151 | error_handler: self.error_handler.clone(), 152 | phantom: self.phantom.clone(), 153 | } 154 | } 155 | } 156 | 157 | impl UuidField { 158 | pub fn new>(field_name: S) -> Self { 159 | let field_name = field_name.as_ref().to_string(); 160 | 161 | Self { 162 | field_name, 163 | result: Arc::new(Mutex::new(None)), 164 | validated: Arc::new(AtomicBool::new(false)), 165 | error_handler: None, 166 | phantom: PhantomData, 167 | } 168 | } 169 | 170 | pub fn handle_error_message( 171 | mut self, 172 | callback: fn(UuidFieldError, Vec) -> Vec, 173 | ) -> Self { 174 | 
self.error_handler = Some(Arc::new(Box::new(callback)));
175 | self 176 | } 177 | 178 | pub async fn value(self) -> T 179 | where 180 | T: 'static, 181 | { 182 | if !self.validated.load(Ordering::Relaxed) { 183 | panic!("This field is not validated. Please call form.validate() method before accessing value."); 184 | } 185 | 186 | let mut lock = self.result.lock().await; 187 | if let Some(result) = lock.take() { 188 | match result.downcast::() { 189 | Ok(t) => { 190 | return *t; 191 | } 192 | _ => {} 193 | }; 194 | } 195 | panic!("Unexpected error. Bug in uuid_field.rs file."); 196 | } 197 | } 198 | 199 | impl AbstractFields for UuidField { 200 | fn field_name(&self) -> FieldResult { 201 | let field_name = self.field_name.clone(); 202 | Box::new(Box::pin(async move { field_name })) 203 | } 204 | 205 | fn validate( 206 | &mut self, 207 | form_data: &mut FormData, 208 | _: &mut Files, 209 | ) -> FieldResult>> { 210 | let field_name = self.field_name.clone(); 211 | let mut values = form_data.remove(&field_name); 212 | let result_ref = self.result.clone(); 213 | let validated = self.validated.clone(); 214 | 215 | let error_handler = self.error_handler.clone(); 216 | 217 | Box::new(Box::pin(async move { 218 | let is_empty; 219 | let is_optional = T::is_optional(); 220 | 221 | let mut errors: Vec = vec![]; 222 | 223 | if let Some(mut values) = values.as_mut() { 224 | is_empty = values.is_empty(); 225 | let submitted_values = values.clone(); // from_vec consumes entries; keep them so InvalidUuid reports what was actually sent. 226 | let option_t = T::from_vec(&mut values); 227 | if let Some(t) = option_t { 228 | let result_ref = result_ref.clone(); 229 | let mut result = result_ref.lock().await; 230 | *result = Some(Box::new(t)); 231 | } else { 232 | let default_uuid_invalid_error = "Invalid UUId.".to_string(); 233 | if let Some(error_handler) = error_handler.clone() { 234 | let invalid_uuid_error = UuidFieldError::InvalidUuid(&field_name, &submitted_values); 235 | let custom_errors = 236 | error_handler(invalid_uuid_error, vec![default_uuid_invalid_error]); 237 | 
errors.extend_from_slice(&custom_errors); 238 | } else { 239 |
errors.push(default_uuid_invalid_error); 240 | } 241 | } 242 | } else { 243 | is_empty = true; 244 | } 245 | 246 | if !is_optional && is_empty { 247 | let default_uuid_missing_error = "This field is required.".to_string(); 248 | 249 | if let Some(error_handler) = error_handler.clone() { 250 | let uuid_missing_error = UuidFieldError::MissingField(&field_name); 251 | let custom_errors = 252 | error_handler(uuid_missing_error, vec![default_uuid_missing_error]); 253 | errors.extend_from_slice(&custom_errors); 254 | } else { 255 | errors.push(default_uuid_missing_error); 256 | } 257 | } 258 | 259 | if errors.len() > 0 { 260 | return Err(errors); 261 | } 262 | 263 | if is_optional && is_empty { 264 | let value_t = T::from_vec(&mut vec![]); 265 | 266 | if let Some(t) = value_t { 267 | let mut result = result_ref.lock().await; 268 | *result = Some(Box::new(t)); 269 | } 270 | } 271 | 272 | validated.store(true, Ordering::Relaxed); 273 | Ok(()) 274 | })) 275 | } 276 | 277 | fn wrap(&self) -> Box { 278 | Box::new(self.clone()) 279 | } 280 | } 281 | 282 | #[cfg(test)] 283 | pub mod tests { 284 | use crate::core::forms::{Files, FormData}; 285 | use crate::forms::fields::uuid_field::UuidField; 286 | use crate::forms::fields::AbstractFields; 287 | 288 | use uuid::Uuid; 289 | 290 | #[tokio::test] 291 | async fn test_uuid_validate_required() { 292 | let mut uuid_field: UuidField = UuidField::new("key"); 293 | let mut forms_data = FormData::new(); 294 | let mut files = Files::new(); 295 | 296 | let result = uuid_field.validate(&mut forms_data, &mut files).await; 297 | assert_eq!(false, result.is_ok()); 298 | 299 | let mut uuid_field2: UuidField = UuidField::new("key"); 300 | forms_data.insert("key".to_string(), vec!["abcd".to_string()]); 301 | let result = uuid_field2.validate(&mut forms_data, &mut files).await; 302 | assert_eq!(false, result.is_ok()); 303 | 304 | // Clear form field values 305 | forms_data.clear(); 306 | 307 | let mut uuid_field3: UuidField = 
UuidField::new("key");
308 | forms_data.insert( 309 | "key".to_string(), 310 | vec!["1130fc58-e9dd-4fce-aa7a-cb41cebdebe1".to_string()], 311 | ); 312 | let result = uuid_field3.validate(&mut forms_data, &mut files).await; 313 | assert_eq!(true, result.is_ok()); 314 | } 315 | 316 | #[tokio::test] 317 | async fn test_uuid_optional() { 318 | let mut uuid_field: UuidField> = UuidField::new("key"); 319 | let mut forms_data = FormData::new(); 320 | let mut files = Files::new(); 321 | let result = uuid_field.validate(&mut forms_data, &mut files).await; 322 | assert_eq!(true, result.is_ok()); 323 | assert_eq!(None, uuid_field.value().await); 324 | } 325 | 326 | #[tokio::test] 327 | async fn test_uuid_vec() { 328 | let mut uuid_field: UuidField> = UuidField::new("key"); 329 | let mut forms_data = FormData::new(); 330 | forms_data.insert( 331 | "key".to_string(), 332 | vec!["1130fc58-e9dd-4fce-aa7a-cb41cebdebe1".to_string()], 333 | ); 334 | let mut files = Files::new(); 335 | let result = uuid_field.validate(&mut forms_data, &mut files).await; 336 | assert_eq!(true, result.is_ok()); 337 | assert_eq!(1, uuid_field.value().await.len()); 338 | } 339 | 340 | #[tokio::test] 341 | async fn test_uuid_optional_vec() { 342 | let mut uuid_field: UuidField>> = UuidField::new("key"); 343 | let mut forms_data = FormData::new(); 344 | let mut files = Files::new(); 345 | let result = uuid_field.validate(&mut forms_data, &mut files).await; 346 | assert_eq!(true, result.is_ok()); 347 | assert_eq!(None, uuid_field.value().await); 348 | } 349 | } 350 | -------------------------------------------------------------------------------- /src/core/websocket/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod frame; 2 | 3 | use std::sync::atomic::{AtomicBool, Ordering}; 4 | use std::sync::Arc; 5 | use std::time::Duration; 6 | 7 | use base64::Engine; 8 | use serde_json::Value; 9 | use sha1::{Digest, Sha1}; 10 | use uuid::Uuid; 11 | 12 | use crate::core::headers::{HeaderValue,
Headers}; 13 | use crate::core::request::Request; 14 | use crate::core::response::status::ResponseStatus; 15 | use crate::core::response::{response_to_bytes, AbstractResponse, HttpResponse}; 16 | use crate::core::stream::Stream; 17 | use crate::core::websocket::frame::{reader, Frame}; 18 | use crate::{racoon_debug, racoon_error}; 19 | 20 | const DEFAULT_MAX_PAYLOAD_SIZE: u64 = 5 * 1024 * 1024; // 5 MiB 21 | 22 | pub enum Message { 23 | Continue(Vec), 24 | Text(String), 25 | Binary(Vec), 26 | Close(u16, String), // (close code, close reason) 27 | Ping(), 28 | Pong(), 29 | Others(Vec), 30 | } 31 | 32 | pub struct WebSocket { 33 | pub uid: String, 34 | stream: Arc, 35 | request_validated: bool, 36 | receive_next: Arc, 37 | headers: Headers, 38 | body: Vec, 39 | } 40 | 41 | impl Clone for WebSocket { 42 | fn clone(&self) -> Self { 43 | Self { 44 | uid: self.uid.clone(), 45 | stream: self.stream.clone(), 46 | request_validated: self.request_validated.clone(), 47 | receive_next: self.receive_next.clone(), 48 | headers: self.headers.clone(), 49 | body: self.body.clone(), 50 | } 51 | } 52 | } 53 | 54 | impl AbstractResponse for WebSocket { 55 | fn status(&self) -> (u32, String) { 56 | (200, "OK".to_string()) 57 | } 58 | 59 | fn serve_default(&mut self) -> bool { 60 | false 61 | } 62 | 63 | fn get_headers(&mut self) -> &mut Headers { 64 | &mut self.headers 65 | } 66 | 67 | fn get_body(&mut self) -> &mut Vec { 68 | &mut self.body 69 | } 70 | 71 | fn should_close(&mut self) -> bool { 72 | true 73 | } 74 | } 75 | 76 | impl WebSocket { 77 | pub async fn from(request: &Request) -> (Self, bool) { 78 | Self::from_opt(request, true).await 79 | } 80 | 81 | pub async fn from_opt(request: &Request, periodic_ping: bool) -> (Self, bool) { 82 | let instance = match WebSocket::validate(request).await { 83 | Ok(instance) => instance, 84 | Err(error) => { 85 | racoon_error!("WS Error: {}", error); 86 | 87 | let failed = Self { 88 | uid: Uuid::new_v4().to_string(), 89 | stream: request.stream.clone(), 90 |
request_validated: false, 91 | receive_next: Arc::new(AtomicBool::new(true)), 92 | headers: Headers::new(), 93 | body: Vec::new(), 94 | }; 95 | return (failed, false); 96 | } 97 | }; 98 | 99 | if periodic_ping { 100 | instance.ping_with_interval(Duration::from_secs(10)).await; 101 | } 102 | 103 | (instance, true) 104 | } 105 | 106 | async fn validate(request: &Request) -> Result { 107 | if request.method != "GET" { 108 | return Err("Invalid request method.".to_owned()); 109 | } 110 | 111 | // Validate connection header 112 | if let Some(value) = request.headers.value("Connection") { 113 | // Connection header can contain multiple values separated by comma. 114 | // Checks if 'upgrade' is specified or not. If not returns error. 115 | if !value.to_lowercase().contains("upgrade") { 116 | return Err("Connection header does not specify to upgrade".to_string()); 117 | } 118 | } else { 119 | return Err("Connection header is missing.".to_string()); 120 | } 121 | 122 | let upgrade; 123 | if let Some(value) = request.headers.value("Upgrade") { 124 | upgrade = value; 125 | } else { 126 | return Err("Upgrade header is missing.".to_string()); 127 | }; 128 | 129 | let sec_websocket_key; 130 | if let Some(value) = request.headers.value("Sec-WebSocket-Key") { 131 | // According to RFC, any leading or trailing spaces must be removed.
132 | sec_websocket_key = value.trim().to_string(); 133 | } else { 134 | return Err("Sec-WebSocket-Key header is missing".to_string()); 135 | } 136 | 137 | if upgrade.to_lowercase() == "websocket" { 138 | } else { 139 | return Err("Upgrade header is not set to websocket.".to_string()); 140 | } 141 | 142 | let instance = Self { 143 | uid: Uuid::new_v4().to_string(), 144 | stream: request.stream.clone(), 145 | request_validated: true, 146 | receive_next: Arc::new(AtomicBool::new(false)), 147 | headers: Headers::new(), 148 | body: Vec::new(), 149 | }; 150 | 151 | match Self::handshake(request.stream.clone(), &sec_websocket_key).await { 152 | Ok(()) => {} 153 | Err(error) => { 154 | return Err(format!("Failed to handshake. {}", error)); 155 | } 156 | }; 157 | 158 | instance.receive_next.store(true, Ordering::Relaxed); 159 | Ok(instance) 160 | } 161 | 162 | /// 163 | /// More information: 164 | /// 165 | async fn handshake(stream: Arc, sec_websocket_key: &str) -> std::io::Result<()> { 166 | let base64_hash = Self::handshake_key_base64(sec_websocket_key); 167 | 168 | let mut http_response = HttpResponse::switching_protocols(); 169 | let headers = http_response.get_headers(); 170 | headers.set("Connection", "upgrade"); 171 | headers.set("Upgrade", "websocket"); 172 | headers.set("Sec-WebSocket-Accept", base64_hash.as_bytes()); 173 | 174 | let mut response: Box = http_response.empty(); 175 | let response_bytes = response_to_bytes(&mut response); 176 | Ok(stream.write_chunk(&response_bytes).await?)
177 | } 178 | 179 | fn handshake_key_base64(sec_websocket_key: &str) -> String { 180 | // WebSocket GUID constant 181 | const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; 182 | let new_key = format!("{}{}", sec_websocket_key.trim(), WEBSOCKET_GUID); 183 | 184 | // Generates Sha1 hash 185 | let mut hasher = Sha1::new(); 186 | hasher.update(new_key); 187 | let hash_result = hasher.finalize().to_vec(); 188 | 189 | // Encodes to base 64 190 | base64::engine::general_purpose::STANDARD.encode(hash_result) 191 | } 192 | 193 | async fn ping_with_interval(&self, duration: Duration) { 194 | let stream = self.stream.clone(); 195 | let receive_next = self.receive_next.clone(); 196 | 197 | tokio::spawn(async move { 198 | racoon_debug!("Sending periodic ping frames..."); 199 | 200 | let mut interval = tokio::time::interval(duration); 201 | 202 | // More information: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 203 | let frame = Frame { 204 | fin: 1, 205 | op_code: 9, 206 | payload: vec![], 207 | }; 208 | 209 | let bytes = frame::builder::build(&frame); 210 | interval.tick().await; 211 | 212 | loop { 213 | interval.tick().await; 214 | if !receive_next.load(Ordering::Relaxed) { break; } // Connection closed or failed: stop pinging. 215 | racoon_debug!("Sending ping..."); 216 | match stream.write_chunk(&bytes).await { 217 | Ok(()) => {} 218 | Err(error) => { 219 | // Ping failed, so if messages are waiting, stops waiting new messages. 220 | receive_next.store(false, Ordering::Relaxed); 221 | racoon_debug!("Ping failed.
Error: {}", error); 222 | break; 223 | } 224 | } 225 | } 226 | }); 227 | } 228 | 229 | async fn send_pong(&self, payload: &[u8]) { 230 | racoon_debug!("Sending pong frame."); 231 | // A pong must echo the application data of the ping it answers. 232 | // More information: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 233 | let frame = Frame { 234 | fin: 1, 235 | op_code: 10, 236 | payload: payload.to_vec(), 237 | }; 238 | 239 | let bytes = frame::builder::build(&frame); 240 | match self.stream.write_chunk(&bytes).await { 241 | Ok(()) => {} 242 | Err(error) => { 243 | // Pong failed, so stops receiving messages. 244 | self.receive_next.store(false, Ordering::Relaxed); 245 | racoon_debug!("Pong failed. Error: {}", error); 246 | } 247 | } 248 | } 249 | 250 | pub async fn receive_message_with_limit(&self, max_payload_size: u64) -> Option { 251 | if !self.receive_next.load(Ordering::Relaxed) { return None; } 252 | // Reassembled payload of the current data message and the opcode of its first frame. 253 | let mut response: Vec = vec![]; 254 | let mut message_op_code = None; 255 | loop { 256 | let frame = match reader::read_frame(self.stream.clone(), max_payload_size).await { 257 | Ok(frame) => frame, 258 | Err(error) => { 259 | // Stops waiting for new messages 260 | self.receive_next.store(false, Ordering::Relaxed); 261 | return Some(Message::Close(1000, error.to_string())); 262 | } 263 | }; 264 | // Control frames (opcode >= 8) are never fragmented and may arrive between data fragments. 265 | if frame.op_code >= 8 { 266 | if frame.op_code == 8 { 267 | // Connection close frame 268 | self.receive_next.store(false, Ordering::Relaxed); 269 | let close_code = self.close_code_from_payload(&frame.payload); 270 | let close_message = self.close_message_from_payload(&frame.payload); 271 | return Some(Message::Close(close_code, close_message)); 272 | } else if frame.op_code == 9 { 273 | // Ping frame: answered immediately with a pong echoing its payload. 274 | 
self.send_pong(&frame.payload).await;
275 | } 276 | if message_op_code.is_some() { 277 | continue; // Mid-message: keep reassembling instead of dropping the partial payload. 278 | } 279 | return match frame.op_code { 280 | 9 => Some(Message::Ping()), 281 | 10 => Some(Message::Pong()), 282 | _ => Some(Message::Others(frame.payload)), 283 | }; 284 | } 285 | // Data frame: the first one fixes the message type; continuation frames (opcode 0) append. 286 | let op_code = *message_op_code.get_or_insert(frame.op_code); 287 | response.extend(&frame.payload); 288 | // Checks the reassembled size against the caller's limit. 289 | if response.len() > max_payload_size as usize { 290 | self.receive_next.store(false, Ordering::Relaxed); 291 | return Some(Message::Close(1009, "Max payload size exceed.".to_string())); 292 | } 293 | // If fin is 1, the complete message is received. 294 | if frame.fin == 1 { 295 | return match op_code { 296 | 0 => Some(Message::Continue(response)), 297 | 1 => Some(Message::Text(String::from_utf8_lossy(&response).to_string())), // Text frame 298 | 2 => Some(Message::Binary(response)), // Binary frame 299 | _ => Some(Message::Others(response)), 300 | }; 301 | } 302 | } 303 | } 304 | 305 | pub async fn message(&self) -> Option { 306 | self.receive_message_with_limit(DEFAULT_MAX_PAYLOAD_SIZE) 307 | .await 308 | } 309 | 310 | pub async fn send_text>(&self, message: S) -> std::io::Result<()> { 311 | let message = message.as_ref(); 312 | 313 | let frame = Frame { 314 | fin: 1, 315 | op_code: 1, 316 | payload: message.as_bytes().to_vec(), 317 | }; 318 | 319 | let bytes = frame::builder::build(&frame); 320 | self.stream.write_chunk(&bytes).await?; 321 | Ok(()) 322 | } 323 | 324 | pub async fn send_bytes>(&self, bytes: B) -> std::io::Result<()> { 325 | let payload = Vec::from(bytes.as_ref()); 326 | 327 | let frame = Frame { 328 | fin: 1, 329 | op_code: 2, 330 | payload, 331 | }; 332 | 333 | let bytes = frame::builder::build(&frame); 334 | self.stream.write_chunk(&bytes).await?; 335 | 336 | Ok(()) 
337 | } 338 | 339 | pub async fn send_json(&self, json: &Value) -> std::io::Result<()> { 340 | self.send_text(json.to_string().as_str()).await 341 | } 342 | 343
| pub async fn bad_request(self) -> Box { 344 | let mut response: Box = 345 | HttpResponse::bad_request().body("Bad Request"); 346 | let response_bytes = response_to_bytes(&mut response); 347 | let _ = self.stream.write_chunk(&response_bytes).await; 348 | Box::new(self) 349 | } 350 | 351 | pub async fn close(&self) { 352 | let _ = self.stream.shutdown().await; 353 | } 354 | 355 | pub fn exit(self) -> Box { 356 | Box::new(self) 357 | } 358 | 359 | fn close_code_from_payload(&self, response: &[u8]) -> u16 { 360 | if response.len() >= 2 { 361 | let mut tmp_bytes = [0u8; 2]; 362 | tmp_bytes.copy_from_slice(&response[..2]); // Close code is the first 2 bytes; a reason may follow. 363 | return u16::from_be_bytes(tmp_bytes); 364 | } 365 | 366 | racoon_debug!( 367 | "Close payload length expected more than 2. But found: {}", 368 | response.len() 369 | ); 370 | return 0; 371 | } 372 | 373 | fn close_message_from_payload(&self, response: &[u8]) -> String { 374 | if response.len() < 3 { 375 | return "No close message specified.".to_string(); 376 | } 377 | 378 | let message_bytes = &response[2..]; 379 | String::from_utf8_lossy(&message_bytes).to_string() 380 | } 381 | } 382 | -------------------------------------------------------------------------------- /src/forms/fields/file_field.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::marker::PhantomData; 3 | use std::path::PathBuf; 4 | use std::sync::atomic::{AtomicBool, Ordering}; 5 | use std::sync::Arc; 6 | 7 | use async_tempfile::TempFile; 8 | use tokio::sync::Mutex; 9 | 10 | use crate::core::forms::{Files, FormData}; 11 | use crate::forms::AbstractFields; 12 | 13 | use crate::forms::fields::FieldResult; 14 | 15 | pub struct UploadedFile { 16 | pub filename: String, 17 | core_file_field: crate::core::forms::FileField, 18 | pub temp_path: PathBuf, 19 | } 20 | 21 | impl UploadedFile { 22 | pub fn from_core_file_field(core_file_field: crate::core::forms::FileField) 
-> Self { 23 | let temp_path =
core_file_field.temp_path.clone(); 24 | let filename = core_file_field.name.clone(); 25 | 26 | Self { 27 | filename, 28 | core_file_field, 29 | temp_path, 30 | } 31 | } 32 | /// Borrows the underlying core file field; only shared access is needed. 33 | pub fn core_file_field(&self) -> &crate::core::forms::FileField { 34 | &self.core_file_field 35 | } 36 | 37 | pub fn from_temp_file>(filename: S, temp_file: TempFile) -> Self { 38 | let filename = filename.as_ref().to_string(); 39 | let core_file_field = crate::core::forms::FileField::from(&filename, temp_file); 40 | let temp_path = core_file_field.temp_path.clone(); 41 | 42 | Self { 43 | filename, 44 | core_file_field, 45 | temp_path, 46 | } 47 | } 48 | } 49 | 50 | pub type PostValidator = Box Result>>; 51 | type BoxResult = Box; 52 | 53 | pub struct FileField { 54 | field_name: String, 55 | result: Arc>>, 56 | post_validator: Option>, 57 | validated: Arc, 58 | phantom: PhantomData, 59 | } 60 | 61 | impl Clone for FileField { 62 | fn clone(&self) -> Self { 63 | Self { 64 | field_name: self.field_name.clone(), 65 | result: self.result.clone(), 66 | post_validator: self.post_validator.clone(), 67 | validated: self.validated.clone(), 68 | phantom: self.phantom.clone(), 69 | } 70 | } 71 | } 72 | 73 | pub trait ToOptionT { 74 | fn from_vec(files: &mut Vec) -> Option 75 | where 76 | Self: Sized; 77 | 78 | fn is_optional() -> bool; 79 | } 80 | 81 | impl ToOptionT for UploadedFile { 82 | fn from_vec(files: &mut Vec) -> Option { 83 | if files.len() > 0 { 84 | let file_field = files.remove(0); 85 | return Some(UploadedFile::from_core_file_field(file_field)); 86 | } 87 | 88 | None 89 | } 90 | 91 | fn is_optional() -> bool { 92 | false 93 | } 94 | } 95 | 96 | impl ToOptionT for Option { 97 | fn from_vec(files: &mut Vec) -> Option { 98 | if files.len() > 0 { 99 | let file_field = files.remove(0); 100 | // Outer Some denotes successful conversion.
101 | return Some(Some(UploadedFile::from_core_file_field(file_field))); 102 | } 103 | 104 | // Return successful conversion but no files are present. So returns actual value as None. 105 | Some(None) 106 | } 107 | 108 | fn is_optional() -> bool { 109 | true 110 | } 111 | } 112 | 113 | impl ToOptionT for Vec { 114 | fn from_vec(files: &mut Vec) -> Option 115 | where 116 | Self: Sized, 117 | { 118 | if !files.is_empty() { 119 | // drain(..) hands every file out in submission order and leaves `files` empty, 120 | // in O(n) (repeated remove/insert at the front of a Vec would be O(n^2)). 121 | let owned_files = files 122 | .drain(..) 123 | .map(UploadedFile::from_core_file_field) 124 | .collect(); 125 | 126 | return Some(owned_files); 127 | } 128 | 129 | // Conversion to type T failed. 130 | None 131 | } 132 | 133 | fn is_optional() -> bool { 134 | false 135 | } 136 | } 137 | 138 | impl ToOptionT for Option> { 139 | fn from_vec(files: &mut Vec) -> Option 140 | where 141 | Self: Sized, 142 | { 143 | if !files.is_empty() { 144 | // drain(..) hands every file out in submission order and leaves `files` empty, 145 | // in O(n) (repeated remove/insert at the front of a Vec would be O(n^2)). 146 | let owned_files = files 147 | .drain(..) 148 | .map(UploadedFile::from_core_file_field) 149 | .collect(); 150 | 151 | return Some(Some(owned_files)); 152 | } 153 | 154 | // Conversion to type T successful because of optional field. So returns None as result.
155 | Some(None) 156 | } 157 | 158 | fn is_optional() -> bool { 159 | true 160 | } 161 | } 162 | 163 | impl FileField { 164 | pub fn new>(field_name: S) -> Self { 165 | let field_name = field_name.as_ref().to_string(); 166 | Self { 167 | field_name, 168 | result: Arc::new(Mutex::new(None)), 169 | post_validator: None, 170 | validated: Arc::new(AtomicBool::from(false)), 171 | phantom: PhantomData, 172 | } 173 | } 174 | 175 | pub fn post_validate(mut self, callback: fn(T) -> Result>) -> Self { 176 | self.post_validator = Some(Box::new(callback)); 177 | self 178 | } 179 | 180 | pub async fn value(self) -> T { 181 | if !self.validated.load(Ordering::Relaxed) { 182 | panic!("This field is not validated. Please call form.validate() method before accessing value."); 183 | } 184 | 185 | let mut result_ref = self.result.lock().await; 186 | let result = result_ref.take(); 187 | 188 | if let Some(result) = result { 189 | match result.downcast::() { 190 | Ok(t) => { 191 | return *t; 192 | } 193 | 194 | _ => {} 195 | }; 196 | } 197 | 198 | panic!("Unexpected error.
Bug in file_field.rs file."); 199 | } 200 | } 201 | 202 | impl AbstractFields for FileField { 203 | fn field_name(&self) -> FieldResult { 204 | let field_name = self.field_name.clone(); 205 | Box::new(Box::pin(async move { field_name })) 206 | } 207 | 208 | fn validate( 209 | &mut self, 210 | _: &mut FormData, 211 | files: &mut Files, 212 | ) -> FieldResult>> { 213 | let files = files.remove(&self.field_name); 214 | let result_ref = self.result.clone(); 215 | let validated = self.validated.clone(); 216 | let post_validator = self.post_validator.clone(); 217 | 218 | Box::new(Box::pin(async move { 219 | let mut errors = vec![]; 220 | 221 | let is_optional = T::is_optional(); 222 | 223 | let is_empty; 224 | 225 | if let Some(mut files) = files { 226 | let mut result = result_ref.lock().await; 227 | is_empty = files.is_empty(); 228 | 229 | if let Some(t) = T::from_vec(&mut files) { 230 | if let Some(post_validator) = post_validator { 231 | match post_validator(t) { 232 | Ok(t) => { 233 | *result = Some(Box::new(t)); 234 | } 235 | Err(custom_errors) => { 236 | errors.extend_from_slice(&custom_errors); 237 | } 238 | } 239 | } else { 240 | *result = Some(Box::new(t)); 241 | } 242 | } 243 | } else { 244 | is_empty = true; 245 | } 246 | 247 | if !is_optional && is_empty { 248 | errors.push("This field is required.".to_string()); 249 | } 250 | 251 | if errors.len() > 0 { 252 | return Err(errors); 253 | } 254 | 255 | if is_optional && is_empty { 256 | let value_t = T::from_vec(&mut vec![]); 257 | if let Some(t) = value_t { 258 | let mut result = result_ref.lock().await; 259 | *result = Some(Box::new(t)); 260 | } 261 | } 262 | 263 | validated.store(true, Ordering::Relaxed); 264 | Ok(()) 265 | })) 266 | } 267 | 268 | fn wrap(&self) -> Box { 269 | Box::new(self.clone()) 270 | } 271 | } 272 | 273 | #[cfg(test)] 274 | pub mod tests { 275 | use async_tempfile::TempFile; 276 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 277 | 278 | use crate::core::forms::{Files, FormData}; 279 |
use crate::forms::fields::AbstractFields; 280 | 281 | use super::{FileField, UploadedFile}; 282 | 283 | #[tokio::test] 284 | async fn test_file_optional() { 285 | let mut form_data = FormData::new(); 286 | let mut files = Files::new(); 287 | 288 | let mut file_field: FileField> = FileField::new("file"); 289 | let result = file_field.validate(&mut form_data, &mut files).await; 290 | 291 | assert_eq!(true, result.is_ok()); 292 | } 293 | 294 | #[tokio::test] 295 | async fn test_file_empty() { 296 | let mut form_data = FormData::new(); 297 | let mut files = Files::new(); 298 | 299 | let mut file_field: FileField = FileField::new("file"); 300 | let result = file_field.validate(&mut form_data, &mut files).await; 301 | 302 | assert_eq!(false, result.is_ok()); 303 | } 304 | 305 | #[tokio::test] 306 | async fn test_file_validate() { 307 | let mut form_data = FormData::new(); 308 | let mut files = Files::new(); 309 | 310 | let mut temp_file = TempFile::new().await.unwrap(); 311 | let _ = temp_file.write_all(b"Hello World").await; 312 | 313 | let core_file_field = crate::core::forms::FileField::from( 314 | "file.txt".to_string(), 315 | temp_file, 316 | ); 317 | 318 | let mut file_field: FileField = FileField::new("file"); 319 | files.insert("file".to_string(), vec![core_file_field]); 320 | let result = file_field.validate(&mut form_data, &mut files).await; 321 | 322 | let path_field = file_field.value().await; 323 | let path_buf = path_field.temp_path; 324 | 325 | assert_eq!(true, path_buf.exists()); 326 | assert_eq!(true, result.is_ok()); 327 | 328 | let mut file = tokio::fs::File::open(&path_buf).await.unwrap(); 329 | let mut content = String::new(); 330 | let _ = file.read_to_string(&mut content).await; 331 | assert_eq!("Hello World".to_string(), content); 332 | } 333 | 334 | #[tokio::test] 335 | async fn test_file_validate_vec() { 336 | let mut form_data = FormData::new(); 337 | let mut files = Files::new(); 338 | 339 | let temp_file = TempFile::new().await.unwrap(); 340 
| let core_file_field = crate::core::forms::FileField::from("file.txt", temp_file); 341 | 342 | let mut file_field: FileField> = FileField::new("file"); 343 | files.insert("file".to_string(), vec![core_file_field]); 344 | let result = file_field.validate(&mut form_data, &mut files).await; 345 | assert_eq!(true, result.is_ok()); 346 | 347 | let sent_files = file_field.value().await; 348 | assert_eq!(1, sent_files.len()); 349 | } 350 | 351 | #[tokio::test] 352 | async fn test_file_validate_vec_optional() { 353 | let mut form_data = FormData::new(); 354 | let mut files = Files::new(); 355 | 356 | let temp_file = TempFile::new().await.unwrap(); 357 | let core_file_field = crate::core::forms::FileField::from( 358 | "file.txt".to_string(), 359 | temp_file, 360 | ); 361 | 362 | let mut file_field: FileField>> = FileField::new("file"); 363 | files.insert("file".to_string(), vec![core_file_field]); 364 | let result = file_field.validate(&mut form_data, &mut files).await; 365 | assert_eq!(true, result.is_ok()); 366 | 367 | let sent_files = file_field.value().await; 368 | assert_eq!(true, sent_files.is_some()); 369 | assert_eq!(1, sent_files.unwrap().len()); 370 | 371 | // Empty test 372 | 373 | let mut form_data = FormData::new(); 374 | let mut files = Files::new(); 375 | 376 | let mut file_field: FileField>> = FileField::new("file"); 377 | let result = file_field.validate(&mut form_data, &mut files).await; 378 | assert_eq!(true, result.is_ok()); 379 | assert_eq!(false, file_field.value().await.is_some()); 380 | } 381 | 382 | #[tokio::test] 383 | async fn test_post_validate() { 384 | let mut form_data = FormData::new(); 385 | let mut files = Files::new(); 386 | 387 | let temp_file = TempFile::new().await.unwrap(); 388 | let core_file_field = crate::core::forms::FileField::from( 389 | "file.txt".to_string(), 390 | temp_file, 391 | ); 392 | 393 | let mut file_field: FileField = 394 | FileField::new("file").post_validate(|file| { 395 | if !file.filename.eq("file2.txt") { 396 | 
return Err(vec!["File name does not equal file2.txt".to_string()]); 397 | } 398 | 399 | Ok(file) 400 | }); 401 | files.insert("file".to_string(), vec![core_file_field]); 402 | let result = file_field.validate(&mut form_data, &mut files).await; 403 | assert_eq!(false, result.is_ok()); 404 | } 405 | } 406 | -------------------------------------------------------------------------------- /src/core/session/managers.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | use std::path::PathBuf; 3 | use std::str::FromStr; 4 | use std::sync::Arc; 5 | 6 | use sqlx::sqlite::SqliteConnectOptions; 7 | use sqlx::ConnectOptions; 8 | use sqlx::Executor; 9 | use sqlx::Pool; 10 | use sqlx::Sqlite; 11 | 12 | use crate::core::session::AbstractSessionManager; 13 | use crate::core::session::SessionResult; 14 | use crate::racoon_debug; 15 | use crate::racoon_error; 16 | 17 | /// 18 | /// FileSessionManager is a default session manager based on the Sqlite database. The database is stored on 19 | /// `.cache/session` file. 20 | /// 21 | /// # Examples 22 | /// 23 | /// ``` 24 | /// use std::env; 25 | /// 26 | /// use racoon::core::session::managers::FileSessionManager; 27 | /// 28 | /// #[tokio::main] 29 | /// async fn main() { 30 | /// // Optional 31 | /// env::set_var("SESSION_FILE_PATH", "../mydb/session"); 32 | /// let session_manager = FileSessionManager::new().await; 33 | /// } 34 | /// ``` 35 | /// 36 | /// The file path can be specified by specifying `SESSION_FILE_PATH` in environment variable. 37 | /// 38 | pub struct FileSessionManager { 39 | db_connection: Arc>>, 40 | } 41 | 42 | impl FileSessionManager { 43 | /// 44 | /// Creates new instance of FileSessonManager. 45 | /// 46 | pub async fn new() -> std::io::Result { 47 | let instance = Self { 48 | db_connection: Arc::new(None), 49 | }; 50 | Ok(instance) 51 | } 52 | 53 | /// 54 | /// Returns stored session file path. 
55 | /// 56 | /// If environment variable `SESSION_FILE_PATH` is specified, it will return the specified path 57 | /// else default relative file path `.cache/session`. 58 | /// 59 | pub fn get_db_path() -> String { 60 | let is_test = env::var("TEST_SESSION").unwrap_or("false".to_string()); 61 | if is_test.to_lowercase() == "true" { 62 | // Returns Sqlite path for testing 63 | racoon_debug!("Using test session database."); 64 | return ".cache/test_session".to_string(); 65 | } 66 | 67 | env::var("SESSION_FILE_PATH").unwrap_or(".cache/session".to_string()) 68 | } 69 | 70 | /// 71 | /// Returns Sqlite pool lazily. If connection pool is not already initialized, it initializes 72 | /// new Sqlite database, creates table and returns the new initialized connection pool. 73 | /// 74 | async fn lazy_connection_pool( 75 | mut db_connection: Arc>>, 76 | ) -> std::io::Result> { 77 | if let Some(db_pool) = db_connection.as_ref() { 78 | return Ok(db_pool.clone()); 79 | } 80 | 81 | let db_path = PathBuf::from(FileSessionManager::get_db_path()); 82 | let db_exists; 83 | 84 | if !db_path.exists() { 85 | racoon_debug!("Session database does not exist. Creating new one."); 86 | 87 | // Session database directory 88 | let mut db_dir = db_path.clone(); 89 | db_dir.pop(); 90 | 91 | db_exists = false; 92 | std::fs::create_dir_all(db_dir)?; 93 | std::fs::File::create_new(&db_path)?; 94 | } else { 95 | db_exists = true; 96 | } 97 | 98 | // Disables sqlx logging 99 | let connect_options = 100 | match SqliteConnectOptions::from_str(db_path.to_string_lossy().as_ref()) { 101 | Ok(options) => options.disable_statement_logging(), 102 | Err(error) => { 103 | return Err(std::io::Error::other(format!( 104 | "Failed to create sqlite connect options for session database. 
Error: {}", 105 | error 106 | ))); 107 | } 108 | }; 109 | 110 | match sqlx::SqlitePool::connect_with(connect_options).await { 111 | Ok(pool) => { 112 | if !db_exists { 113 | const CREATE_SESSION_TABLE_QUERY: &str = r#" 114 | CREATE TABLE session( 115 | id BIGINT AUTO_INCREMENT PRIMARY KEY, 116 | session_id VARCHAR(1025) NOT NULL, 117 | key TEXT NOT NULL UNIQUE, 118 | value TEXT NOT NULL 119 | ) 120 | "#; 121 | 122 | match pool.execute(CREATE_SESSION_TABLE_QUERY).await { 123 | Ok(_) => { 124 | racoon_debug!("Created session table."); 125 | } 126 | Err(error) => { 127 | return Err(std::io::Error::other(format!( 128 | "Failed to create session table. Error: {}", 129 | error 130 | ))); 131 | } 132 | }; 133 | } 134 | db_connection = Arc::from(Some(pool.clone())); 135 | 136 | if let Some(db_connection) = db_connection.as_ref() { 137 | return Ok(db_connection.clone()); 138 | } 139 | 140 | return Err(std::io::Error::other("Error reading connection pool.")); 141 | } 142 | Err(error) => { 143 | return Err(std::io::Error::other(format!( 144 | "Failed to connect sqlite db for managing session. 
Error: {:?}", 145 | error 146 | ))); 147 | } 148 | } 149 | } 150 | } 151 | 152 | impl AbstractSessionManager for FileSessionManager { 153 | fn set( 154 | &self, 155 | session_id: &String, 156 | name: &str, 157 | value: &str, 158 | ) -> SessionResult> { 159 | let db_connection = self.db_connection.clone(); 160 | let session_id = session_id.to_owned(); 161 | let key = name.to_string(); 162 | let value = value.to_string(); 163 | 164 | Box::new(Box::pin(async move { 165 | let db_pool = match Self::lazy_connection_pool(db_connection.clone()).await { 166 | Ok(pool) => pool, 167 | Err(error) => { 168 | return Err(error); 169 | } 170 | }; 171 | 172 | const UPSERT_QUERY: &str = r#" 173 | INSERT INTO session(session_id, key, value) 174 | VALUES ($1, $2, $3) 175 | ON CONFLICT(key) DO UPDATE 176 | SET 177 | session_id=excluded.session_id, 178 | key=excluded.key, 179 | value=excluded.value 180 | "#; 181 | 182 | let result = sqlx::query(UPSERT_QUERY) 183 | .bind(session_id) 184 | .bind(key) 185 | .bind(value) 186 | .execute(&db_pool) 187 | .await; 188 | 189 | match result { 190 | Ok(_) => {} 191 | Err(error) => { 192 | return Err(std::io::Error::other(format!( 193 | "Failed to set session value. Error: {}", 194 | error 195 | ))); 196 | } 197 | }; 198 | 199 | Ok(()) 200 | })) 201 | } 202 | 203 | fn get(&self, session_id: &String, name: &str) -> SessionResult> { 204 | let db_connection = self.db_connection.clone(); 205 | let session_id = session_id.to_owned(); 206 | let key = name.to_string(); 207 | 208 | Box::new(Box::pin(async move { 209 | let db_pool = match Self::lazy_connection_pool(db_connection.clone()).await { 210 | Ok(pool) => pool, 211 | Err(error) => { 212 | racoon_error!( 213 | "Failed to create session database connection pool. 
Error: {}", 214 | error 215 | ); 216 | return None; 217 | } 218 | }; 219 | 220 | const FETCH_QUERY: &str = r#" 221 | SELECT value FROM session 222 | WHERE 223 | session_id=$1 AND key=$2 224 | LIMIT 1 225 | "#; 226 | 227 | let result: Result<(String,), sqlx::Error> = sqlx::query_as(FETCH_QUERY) 228 | .bind(session_id) 229 | .bind(key) 230 | .fetch_one(&db_pool) 231 | .await; 232 | 233 | return match result { 234 | Ok((value,)) => Some(value), 235 | Err(error) => { 236 | racoon_debug!("Failed to fetch session value. Error: {}", error); 237 | return None; 238 | } 239 | }; 240 | })) 241 | } 242 | 243 | fn remove(&self, session_id: &String, name: &str) -> SessionResult> { 244 | let db_connection = self.db_connection.clone(); 245 | let session_id = session_id.to_owned(); 246 | let key = name.to_string(); 247 | 248 | Box::new(Box::pin(async move { 249 | let db_pool = match Self::lazy_connection_pool(db_connection.clone()).await { 250 | Ok(pool) => pool, 251 | Err(error) => { 252 | return Err(std::io::Error::other(format!( 253 | "Failed to create session database connection pool. Error: {}", 254 | error 255 | ))); 256 | } 257 | }; 258 | 259 | const DELETE_QUERY: &str = r#" 260 | DELETE FROM session WHERE session_id=$1 AND key=$2 261 | "#; 262 | 263 | let result = sqlx::query(DELETE_QUERY) 264 | .bind(session_id) 265 | .bind(key) 266 | .execute(&db_pool) 267 | .await; 268 | 269 | return match result { 270 | Ok(_) => Ok(()), 271 | Err(error) => Err(std::io::Error::other(format!( 272 | "Failed to delete session values. 
Error: {}", 273 | error 274 | ))), 275 | }; 276 | })) 277 | } 278 | 279 | fn destroy(&self, session_id: &String) -> SessionResult> { 280 | let db_connection = self.db_connection.clone(); 281 | let session_id = session_id.to_owned(); 282 | 283 | Box::new(Box::pin(async move { 284 | let db_pool = match Self::lazy_connection_pool(db_connection.clone()).await { 285 | Ok(pool) => pool, 286 | Err(error) => { 287 | return Err(std::io::Error::other(format!( 288 | "Failed to create session database connection pool. Error: {}", 289 | error 290 | ))); 291 | } 292 | }; 293 | 294 | const DELETE_QUERY: &str = r#" 295 | DELETE FROM session WHERE session_id=$1 296 | "#; 297 | 298 | let result = sqlx::query(DELETE_QUERY) 299 | .bind(session_id) 300 | .execute(&db_pool) 301 | .await; 302 | 303 | return match result { 304 | Ok(_) => Ok(()), 305 | Err(error) => Err(std::io::Error::other(format!( 306 | "Failed to delete all session values. Error: {}", 307 | error 308 | ))), 309 | }; 310 | })) 311 | } 312 | } 313 | 314 | #[cfg(test)] 315 | pub mod test { 316 | use std::{env, path::PathBuf, str::FromStr}; 317 | 318 | use uuid::Uuid; 319 | 320 | use crate::core::session::AbstractSessionManager; 321 | 322 | use super::FileSessionManager; 323 | 324 | #[tokio::test] 325 | async fn test_file_session() { 326 | // Specifies to use seperate testing database for session 327 | env::set_var("TEST_SESSION", "true"); 328 | let db_path = FileSessionManager::get_db_path(); 329 | assert_eq!(db_path, ".cache/test_session"); 330 | 331 | // Removes existing database file if any 332 | if PathBuf::from_str(&db_path).unwrap().exists() { 333 | let result = tokio::fs::remove_file(&db_path).await; 334 | assert_eq!(true, result.is_ok()); 335 | } 336 | 337 | let session_manager_result = FileSessionManager::new().await; 338 | assert_eq!(true, session_manager_result.is_ok()); 339 | 340 | let session_manager = session_manager_result.unwrap(); 341 | let session_id = Uuid::new_v4().to_string(); 342 | 343 | // tests 
insert 344 | let result = session_manager.set(&session_id, "name", "John").await; 345 | let result2 = session_manager.set(&session_id, "location", "ktm").await; 346 | assert_eq!(true, result.is_ok()); 347 | assert_eq!(true, result2.is_ok()); 348 | 349 | let name = session_manager.get(&session_id, "name").await; 350 | assert_eq!(Some("John".to_string()), name); 351 | 352 | let location = session_manager.get(&session_id, "location").await; 353 | assert_eq!(Some("ktm".to_string()), location); 354 | 355 | // tests removal 356 | let delete_name_result = session_manager.remove(&session_id, "name").await; 357 | assert_eq!(true, delete_name_result.is_ok()); 358 | 359 | let unknown = session_manager.get(&session_id, "name").await; 360 | assert_eq!(None, unknown); 361 | 362 | let name = session_manager.get(&session_id, "name").await; 363 | assert_eq!(None, name); 364 | 365 | // tests destory 366 | let destroy_result = session_manager.destroy(&session_id).await; 367 | assert_eq!(true, destroy_result.is_ok()); 368 | 369 | let location = session_manager.get(&session_id, "location").await; 370 | assert_eq!(None, location); 371 | 372 | let delete_db_result = tokio::fs::remove_file(db_path).await; 373 | assert_eq!(true, delete_db_result.is_ok()); 374 | } 375 | } 376 | -------------------------------------------------------------------------------- /src/core/websocket/frame.rs: -------------------------------------------------------------------------------- 1 | /// 2 | /// Protocol format: 3 | /// 4 | /// ```markdown 5 | /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 6 | /// +-+-+-+-+-------+-+-------------+-------------------------------+ 7 | /// |F|R|R|R| opcode|M| Payload len | Extended payload length | 8 | /// |I|S|S|S| (4) |A| (7) | (16/64) | 9 | /// |N|V|V|V| |S| | (if payload len==126/127) | 10 | /// | |1|2|3| |K| | | 11 | /// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + 12 | /// | Extended payload length continued, if payload len == 
127 | 13 | /// + - - - - - - - - - - - - - - - +-------------------------------+ 14 | /// | |Masking-key, if MASK set to 1 | 15 | /// +-------------------------------+-------------------------------+ 16 | /// | Masking-key (continued) | Payload Data | 17 | /// +-------------------------------- - - - - - - - - - - - - - - - + 18 | /// : Payload Data continued ... : 19 | /// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + 20 | /// | Payload Data continued ... | 21 | /// +---------------------------------------------------------------+ 22 | /// ``` 23 | /// 24 | /// More information: 25 | /// 26 | pub struct Frame { 27 | pub fin: u8, 28 | pub op_code: u8, 29 | pub payload: Vec, 30 | } 31 | 32 | pub mod reader { 33 | use std::sync::Arc; 34 | 35 | use crate::core::stream::Stream; 36 | use crate::core::websocket::frame::Frame; 37 | 38 | use crate::racoon_debug; 39 | 40 | pub async fn read_frame(stream: Arc, max_payload_size: u64) -> std::io::Result { 41 | let mut buffer = vec![]; 42 | 43 | // Reads first 16 bits including FIN, RSV(1, 2, 3), OPCODE and Payload length 44 | while buffer.len() < 2 { 45 | let chunk = stream.read_chunk().await?; 46 | buffer.extend(chunk); 47 | } 48 | 49 | let first_byte = buffer[0]; 50 | let fin = fin_bit_to_u8(&first_byte); 51 | let op_code = opcode_bit_to_u8(&first_byte); 52 | 53 | // 1 bit mask and 7 bit payload length 54 | let second_byte = buffer[1]; 55 | let mask_bit = bit_mask_to_u8(&second_byte); 56 | 57 | let payload_length = payload_length_to_u8(&second_byte); 58 | 59 | // Removes two bytes read from the buffer 60 | buffer.drain(0..2); 61 | 62 | // If length is between 0-125, this is the actual length of the message else actual length is 63 | // set in the next 8 bytes. 64 | let actual_payload_length: u64; 65 | if payload_length < 126 { 66 | actual_payload_length = payload_length as u64; 67 | } else if payload_length == 126 { 68 | // For 127 payload length, actual size is in next two bytes. 
69 | while buffer.len() < 2 { 70 | let chunk = stream.read_chunk().await?; 71 | buffer.extend(chunk); 72 | } 73 | 74 | actual_payload_length = payload_length_to_u16(&buffer[..2])? as u64; 75 | 76 | // Removes used 2 bytes 77 | buffer.drain(0..2); 78 | } else { 79 | // For more than 126 payload length, actual size is in next 8 bytes. 80 | while buffer.len() < 8 { 81 | let chunk = stream.read_chunk().await?; 82 | buffer.extend(chunk); 83 | } 84 | 85 | actual_payload_length = payload_length_to_u64(&buffer[..8])?; 86 | 87 | // Removes used 8 bytes 88 | buffer.drain(0..8); 89 | } 90 | 91 | let masking_key: Option>; 92 | 93 | if mask_bit == 1 { 94 | // Bit mask bit is set to 1, so extracts masking key of 4 bytes. 95 | if buffer.len() < 4 { 96 | let chunk = stream.read_chunk().await?; 97 | buffer.extend(chunk); 98 | } 99 | 100 | let key = (&buffer[..4]).to_owned(); 101 | masking_key = Some(key); 102 | 103 | // Removes read 4 bytes from the buffer 104 | buffer.drain(0..4); 105 | 106 | racoon_debug!("Websocket masking key: {:?}.", &masking_key); 107 | } else { 108 | racoon_debug!("Websocket masking disabled."); 109 | masking_key = None; 110 | } 111 | 112 | if actual_payload_length > max_payload_size { 113 | return Err(std::io::Error::other( 114 | "Payload length is more than the maximum allowed size.", 115 | )); 116 | } 117 | 118 | // Loads message bytes to the buffer 119 | while buffer.len() < actual_payload_length as usize { 120 | let chunk = stream.read_chunk().await?; 121 | buffer.extend(chunk); 122 | } 123 | 124 | // Misread bytes 125 | let extra_read: Vec = buffer.drain(actual_payload_length as usize..).collect(); 126 | 127 | // Decodes websocket message using masking bit 128 | if let Some(masking_key) = masking_key { 129 | // Masking key is 4 bit 130 | for i in 0..buffer.len() { 131 | let masking_byte_index = i % 4; 132 | buffer[i] = buffer[i] ^ &masking_key[masking_byte_index]; 133 | } 134 | } 135 | 136 | if extra_read.len() > 0 { 137 | let _ = 
stream.restore_payload(&extra_read).await; 138 | } 139 | 140 | Ok(Frame { 141 | fin, 142 | op_code, 143 | payload: buffer, 144 | }) 145 | } 146 | 147 | /// 148 | /// Converts final bit value to unsigned number. 149 | /// 150 | fn fin_bit_to_u8(byte: &u8) -> u8 { 151 | byte >> 7 152 | } 153 | 154 | /// 155 | /// Converts 4 bit opcode to unsigned number. 156 | /// 157 | fn opcode_bit_to_u8(byte: &u8) -> u8 { 158 | byte & 0b00001111 159 | } 160 | 161 | /// 162 | /// Converts the bits value 1 or 0 to unsigned number. 163 | /// 164 | fn bit_mask_to_u8(byte: &u8) -> u8 { 165 | byte >> 7 166 | } 167 | 168 | /// 169 | /// Converts 7 bits to unsigned number. 170 | /// 171 | fn payload_length_to_u8(byte: &u8) -> u8 { 172 | byte & 0b01111111 173 | } 174 | 175 | /// 176 | /// Converts 2 bytes array to unsigned number. 177 | /// 178 | fn payload_length_to_u16(bytes: &[u8]) -> std::io::Result { 179 | if bytes.len() != 2 { 180 | return Err(std::io::Error::other(format!( 181 | "Failed to convert payload length to u64. Bytes of size 2 is expected. But found: {}", 182 | bytes.len() 183 | ))); 184 | } 185 | 186 | let mut tmp_bytes = [0; 2]; 187 | tmp_bytes.copy_from_slice(bytes); 188 | Ok(u16::from_be_bytes(tmp_bytes)) 189 | } 190 | 191 | /// 192 | /// Converts 8 bytes array to unsigned number. 193 | /// 194 | fn payload_length_to_u64(bytes: &[u8]) -> std::io::Result { 195 | if bytes.len() != 8 { 196 | return Err(std::io::Error::other(format!( 197 | "Failed to convert payload length to u64. Bytes of size 8 is expected. 
But found: {}", 198 | bytes.len() 199 | ))); 200 | } 201 | 202 | let mut tmp_bytes = [0; 8]; 203 | tmp_bytes.copy_from_slice(bytes); 204 | Ok(u64::from_be_bytes(tmp_bytes)) 205 | } 206 | 207 | #[cfg(test)] 208 | pub mod test { 209 | use std::sync::Arc; 210 | 211 | use crate::core::stream::{AbstractStream, TestStreamWrapper}; 212 | use crate::core::websocket::frame::{builder, Frame}; 213 | 214 | #[tokio::test] 215 | async fn test_read_single_frame() { 216 | let frame = Frame { 217 | fin: 1, 218 | op_code: 1, 219 | payload: "Hello World".as_bytes().to_vec(), 220 | }; 221 | 222 | let frame_bytes = builder::build(&frame); 223 | 224 | let test_stream_wrapper = TestStreamWrapper::new(frame_bytes, 1024); 225 | let stream: Arc> = 226 | Arc::new(Box::new(test_stream_wrapper)); 227 | let result = super::read_frame(stream, 500).await; 228 | 229 | assert_eq!(true, result.is_ok()); 230 | let decoded_frame = result.unwrap(); 231 | 232 | assert_eq!(frame.fin, decoded_frame.fin); 233 | assert_eq!(frame.op_code, decoded_frame.op_code); 234 | assert_eq!(frame.payload, decoded_frame.payload); 235 | } 236 | 237 | #[tokio::test] 238 | async fn test_read_multiple_frames() { 239 | let frame = Frame { 240 | fin: 1, 241 | op_code: 1, 242 | payload: "Hello World".as_bytes().to_vec(), 243 | }; 244 | 245 | let text_frame_bytes = builder::build_opt(&frame, true); 246 | 247 | let frame2 = Frame { 248 | fin: 1, 249 | op_code: 9, 250 | payload: "PING".as_bytes().to_vec(), 251 | }; 252 | let ping_frame_bytes = builder::build_opt(&frame2, true); 253 | 254 | let mut multiple_frame_bytes = text_frame_bytes; 255 | multiple_frame_bytes.extend(&ping_frame_bytes); 256 | 257 | let test_stream_wrapper = TestStreamWrapper::new(multiple_frame_bytes, 1024); 258 | let stream: Arc> = 259 | Arc::new(Box::new(test_stream_wrapper)); 260 | 261 | let result1 = super::read_frame(stream.clone(), 500).await; 262 | 263 | // Check text frame 264 | assert_eq!(true, result1.is_ok()); 265 | let decoded_frame = 
result1.unwrap(); 266 | 267 | assert_eq!(frame.fin, decoded_frame.fin); 268 | assert_eq!(frame.op_code, decoded_frame.op_code); 269 | assert_eq!(frame.payload, decoded_frame.payload); 270 | 271 | // Check ping frame 272 | let result2 = super::read_frame(stream, 500).await; 273 | 274 | // Check text frame 275 | assert_eq!(true, result2.is_ok()); 276 | let decoded_frame2 = result2.unwrap(); 277 | 278 | assert_eq!(frame2.fin, decoded_frame2.fin); 279 | assert_eq!(frame2.op_code, decoded_frame2.op_code); 280 | assert_eq!(frame2.payload, decoded_frame2.payload); 281 | } 282 | } 283 | } 284 | 285 | pub mod builder { 286 | use crate::core::websocket::frame::Frame; 287 | 288 | pub fn build_opt(frame: &Frame, mask: bool) -> Vec { 289 | let mut buffer: Vec = vec![]; 290 | 291 | // Moves fin byte towards MSB 292 | let fin_byte = frame.fin << 7; 293 | let opcode_byte = frame.op_code; 294 | let first_byte = fin_byte | opcode_byte; 295 | buffer.push(first_byte); 296 | 297 | let actual_payload_length = frame.payload.len(); 298 | 299 | // Calculate the length representation and push it to the buffer 300 | if actual_payload_length < 126 { 301 | let mut second_byte = actual_payload_length as u8; 302 | 303 | if mask { 304 | // Adds 1 to MSB 305 | second_byte = second_byte | 0b10000000; 306 | } 307 | 308 | buffer.push(second_byte); 309 | } else if actual_payload_length < (2_usize.pow(16)) { 310 | // Payload length is between 126 and 65535 bytes 311 | buffer.push(126); // Indicates length is in next 2 bytes 312 | 313 | // Convert the length to 2 bytes and push them 314 | let length_bytes: [u8; 2] = (actual_payload_length as u16).to_be_bytes(); 315 | buffer.extend_from_slice(&length_bytes); 316 | } else { 317 | // Payload length is greater than or equal to 65536 bytes 318 | buffer.push(127); // Indicates length is in next 8 bytes 319 | 320 | // Convert the length to 8 bytes and push them 321 | let length_bytes: [u8; 8] = (actual_payload_length as u64).to_be_bytes(); 322 | 
buffer.extend_from_slice(&length_bytes); 323 | } 324 | 325 | let mut payload = frame.payload.clone(); 326 | 327 | if mask { 328 | let mask_bytes:[u8; 4] = rand::random(); 329 | buffer.extend_from_slice(&mask_bytes); 330 | 331 | for i in 0..frame.payload.len() { 332 | let mask_index = i % 4; 333 | payload[i] = (frame.payload[i] as usize ^ mask_bytes[mask_index] as usize) as u8; 334 | } 335 | } 336 | 337 | // Append the payload data to the buffer 338 | buffer.extend_from_slice(&payload); 339 | buffer 340 | } 341 | 342 | pub fn build(frame: &Frame) -> Vec { 343 | // Disables masking message for server to client. 344 | build_opt(frame, false) 345 | } 346 | 347 | #[cfg(test)] 348 | pub mod test { 349 | use std::sync::Arc; 350 | 351 | use crate::core::stream::{AbstractStream, TestStreamWrapper}; 352 | use crate::core::websocket::frame::reader::read_frame; 353 | use crate::core::websocket::frame::Frame; 354 | 355 | use super::build_opt; 356 | 357 | #[tokio::test] 358 | async fn test_frame_build_server() { 359 | let frame = Frame { 360 | fin: 0, 361 | op_code: 1, 362 | payload: "Hello World".as_bytes().to_vec(), 363 | }; 364 | 365 | let frame_bytes = build_opt(&frame, true); 366 | 367 | let test_stream_wrapper = TestStreamWrapper::new(frame_bytes, 1024); 368 | let stream: Arc> = 369 | Arc::new(Box::new(test_stream_wrapper)); 370 | 371 | let reader = read_frame(stream, 1000).await; 372 | assert_eq!(true, reader.is_ok()); 373 | 374 | let frame = reader.unwrap(); 375 | assert_eq!(frame.fin, 0); 376 | assert_eq!(frame.op_code, 1); 377 | assert_eq!(frame.payload, "Hello World".as_bytes().to_vec()); 378 | } 379 | } 380 | } 381 | -------------------------------------------------------------------------------- /src/core/stream/mod.rs: -------------------------------------------------------------------------------- 1 | use std::future::Future; 2 | use std::io::ErrorKind; 3 | use std::net::{Shutdown, SocketAddr}; 4 | use std::sync::atomic::{AtomicBool, Ordering}; 5 | use 
std::sync::Arc; 6 | 7 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 8 | use tokio::io::{ReadHalf, WriteHalf}; 9 | use tokio::net::{TcpStream, UnixStream}; 10 | use tokio::sync::Mutex; 11 | 12 | use tokio_rustls::server::TlsStream; 13 | use tokio_rustls::TlsAcceptor; 14 | 15 | use crate::{racoon_debug, racoon_error}; 16 | 17 | pub type StreamResult<'a, T> = Box + Sync + Send + Unpin + 'a>; 18 | pub type Stream = Box; 19 | 20 | pub trait AbstractStream: Sync + Send { 21 | fn buffer_size(&self) -> StreamResult; 22 | fn peer_addr(&self) -> StreamResult>; 23 | fn restore_payload<'a>(&'a self, bytes: &[u8]) -> StreamResult>; 24 | fn restored_len(&self) -> StreamResult; 25 | fn read_chunk(&self) -> StreamResult>>; 26 | fn write_chunk<'a>(&'a self, bytes: &'a [u8]) -> StreamResult>; 27 | fn shutdown(&self) -> StreamResult>; 28 | } 29 | 30 | #[derive(Debug)] 31 | pub struct TcpStreamWrapper { 32 | stream: Arc>, 33 | reader: Arc>>, 34 | writer: Arc>>, 35 | buffer_size: usize, 36 | restored_payload: Arc>>>, 37 | } 38 | 39 | impl TcpStreamWrapper { 40 | pub fn from(tcp_stream: TcpStream, buffer_size: usize) -> std::io::Result { 41 | // May return "Too many open files error" if all file descriptors are used. 42 | let std_tcp_stream = tcp_stream.into_std()?; 43 | 44 | let async_tcp_stream_rw = match std_tcp_stream.try_clone() { 45 | Ok(std_stream) => TcpStream::from_std(std_stream)?, 46 | Err(err) => { 47 | racoon_error!("Failed to clone std TcpStream to tokio TcpStream. 
Try increasing file descriptor limit."); 48 | racoon_debug!("Shutting down std stream."); 49 | let shutdown_result = std_tcp_stream.shutdown(std::net::Shutdown::Both); 50 | racoon_debug!("Shutdown result: {:?}", shutdown_result); 51 | return Err(err); 52 | } 53 | }; 54 | 55 | // Stream for shutting down later 56 | let (reader, writer) = tokio::io::split(async_tcp_stream_rw); 57 | let async_tcp_stream = TcpStream::from_std(std_tcp_stream)?; 58 | 59 | Ok(Self { 60 | stream: Arc::new(Mutex::new(async_tcp_stream)), 61 | reader: Arc::new(Mutex::new(reader)), 62 | writer: Arc::new(Mutex::new(writer)), 63 | buffer_size, 64 | restored_payload: Arc::new(Mutex::new(None)), 65 | }) 66 | } 67 | } 68 | 69 | impl AbstractStream for TcpStreamWrapper { 70 | fn buffer_size(&self) -> StreamResult { 71 | let buffer_size = self.buffer_size.clone(); 72 | Box::new(Box::pin(async move { buffer_size })) 73 | } 74 | 75 | fn peer_addr(&self) -> StreamResult> { 76 | let stream_ref = self.stream.clone(); 77 | 78 | Box::new(Box::pin(async move { 79 | let stream = stream_ref.lock().await; 80 | 81 | return match stream.peer_addr() { 82 | Ok(addr) => Some(addr), 83 | Err(error) => { 84 | racoon_debug!("Failed to get peer addr. 
Error: {}", error); 85 | None 86 | } 87 | }; 88 | })) 89 | } 90 | 91 | fn restore_payload<'a>(&'a self, bytes: &[u8]) -> StreamResult> { 92 | let restored_payload_ref = self.restored_payload.clone(); 93 | let bytes = bytes.to_vec(); 94 | 95 | Box::new(Box::pin(async move { 96 | let mut restored_payload = restored_payload_ref.lock().await; 97 | *restored_payload = Some(bytes); 98 | Ok(()) 99 | })) 100 | } 101 | 102 | fn restored_len(&self) -> StreamResult { 103 | let restored_payload_ref = self.restored_payload.clone(); 104 | 105 | Box::new(Box::pin(async move { 106 | let restored_payload = restored_payload_ref.lock().await; 107 | 108 | if let Some(restored) = restored_payload.as_ref() { 109 | return restored.len(); 110 | } 111 | 112 | 0 113 | })) 114 | } 115 | 116 | fn read_chunk(&self) -> StreamResult>> { 117 | let restored_payload_ref = self.restored_payload.clone(); 118 | let reader_ref = self.reader.clone(); 119 | let buffer_size = self.buffer_size.clone(); 120 | 121 | Box::new(Box::pin(async move { 122 | // If payload of some bytes is restored after reading the chunk, returns the same bytes 123 | // back to the reader again. 124 | // Reading from stream wrapper is skipped because there may not be any bytes to read. 125 | let mut restored_payload = restored_payload_ref.lock().await; 126 | 127 | if let Some(payload) = restored_payload.take() { 128 | // Leaves None 129 | return Ok(payload); 130 | } 131 | 132 | let mut buffer = vec![0u8; buffer_size]; 133 | let mut reader = reader_ref.lock().await; 134 | 135 | return match reader.read(&mut buffer).await { 136 | Ok(read_size) => { 137 | if read_size == 0 { 138 | return Err(std::io::Error::new( 139 | ErrorKind::BrokenPipe, 140 | "Read size is 0. 
Probably connection broken.", 141 | )); 142 | } 143 | 144 | let chunk: Vec = buffer.drain(0..read_size).collect(); 145 | Ok(chunk) 146 | } 147 | Err(error) => Err(std::io::Error::other(error)), 148 | }; 149 | })) 150 | } 151 | 152 | fn write_chunk<'a>(&'a self, data: &'a [u8]) -> StreamResult> { 153 | let writer_ref = self.writer.clone(); 154 | 155 | Box::new(Box::pin(async move { 156 | let mut writer = writer_ref.lock().await; 157 | writer.write_all(&data).await?; 158 | Ok(()) 159 | })) 160 | } 161 | 162 | fn shutdown(&self) -> StreamResult> { 163 | let stream_ref = self.stream.clone(); 164 | 165 | Box::new(Box::pin(async move { 166 | let mut stream = stream_ref.lock().await; 167 | let _ = stream.shutdown().await; 168 | Ok(()) 169 | })) 170 | } 171 | } 172 | 173 | #[derive(Debug)] 174 | pub struct UnixStreamWrapper { 175 | stream: Arc>, 176 | reader: Arc>>, 177 | writer: Arc>>, 178 | buffer_size: usize, 179 | restored_payload: Arc>>>, 180 | } 181 | 182 | impl UnixStreamWrapper { 183 | pub fn from(unix_stream: UnixStream, buffer_size: usize) -> std::io::Result { 184 | let std_unix_stream = unix_stream.into_std()?; 185 | 186 | let async_unix_stream = match std_unix_stream.try_clone() { 187 | Ok(unix_stream) => UnixStream::from_std(unix_stream)?, 188 | Err(error) => { 189 | racoon_error!("Failed to clone std unix stream."); 190 | let shutdown_result = std_unix_stream.shutdown(Shutdown::Both); 191 | racoon_debug!("Shutdown result: {:?}", shutdown_result); 192 | return Err(error); 193 | } 194 | }; 195 | 196 | let async_writer_rw = UnixStream::from_std(std_unix_stream)?; 197 | let (reader, writer) = tokio::io::split(async_writer_rw); 198 | 199 | Ok(Self { 200 | stream: Arc::new(Mutex::new(async_unix_stream)), 201 | reader: Arc::new(Mutex::new(reader)), 202 | writer: Arc::new(Mutex::new(writer)), 203 | buffer_size, 204 | restored_payload: Arc::new(Mutex::new(None)), 205 | }) 206 | } 207 | } 208 | 209 | impl AbstractStream for UnixStreamWrapper { 210 | fn 
buffer_size(&self) -> StreamResult { 211 | let buffer_size = self.buffer_size.clone(); 212 | Box::new(Box::pin(async move { buffer_size })) 213 | } 214 | 215 | fn peer_addr(&self) -> StreamResult> { 216 | Box::new(Box::pin(async move { 217 | return None; 218 | })) 219 | } 220 | 221 | fn restore_payload(&self, bytes: &[u8]) -> StreamResult> { 222 | let restored_payload = self.restored_payload.clone(); 223 | let bytes = bytes.to_vec(); 224 | 225 | Box::new(Box::pin(async move { 226 | let restored_payload_ref = restored_payload.clone(); 227 | let mut restored_payload = restored_payload_ref.lock().await; 228 | *restored_payload = Some(bytes); 229 | Ok(()) 230 | })) 231 | } 232 | 233 | fn restored_len(&self) -> StreamResult { 234 | let restored_payload = self.restored_payload.clone(); 235 | 236 | Box::new(Box::pin(async move { 237 | let restored_payload_ref = restored_payload.clone(); 238 | let restored_payload = restored_payload_ref.lock().await; 239 | 240 | if let Some(restored) = restored_payload.as_ref() { 241 | return restored.len(); 242 | } 243 | 244 | 0 245 | })) 246 | } 247 | 248 | fn read_chunk(&self) -> StreamResult>> { 249 | // If payload of some bytes is restored after reading the chunk, returns the same bytes 250 | // back to the reader again. 251 | // Reading from stream wrapper is skipped because there may not be any bytes to read. 
252 | let restored_payload_ref = self.restored_payload.clone(); 253 | let buffer_size = self.buffer_size.clone(); 254 | 255 | let reader = self.reader.clone(); 256 | 257 | Box::new(Box::pin(async move { 258 | let mut restored_payload = restored_payload_ref.lock().await; 259 | 260 | if let Some(payload) = restored_payload.as_ref() { 261 | let buffer = payload.to_owned(); 262 | *restored_payload = None; 263 | return Ok(buffer); 264 | } 265 | 266 | let mut buffer = vec![0u8; buffer_size]; 267 | 268 | let reader_ref = reader.clone(); 269 | let mut reader = reader_ref.lock().await; 270 | 271 | return match reader.read(&mut buffer).await { 272 | Ok(read_size) => { 273 | if read_size == 0 { 274 | return Err(std::io::Error::new( 275 | ErrorKind::BrokenPipe, 276 | "Read size is 0. Probably connection broken.", 277 | )); 278 | } 279 | 280 | let chunk = &buffer[0..read_size]; 281 | Ok(chunk.to_vec()) 282 | } 283 | Err(error) => Err(std::io::Error::other(error)), 284 | }; 285 | })) 286 | } 287 | 288 | fn write_chunk(&self, data: &[u8]) -> StreamResult> { 289 | let writer_ref = self.writer.clone(); 290 | let data = data.to_vec(); 291 | 292 | Box::new(Box::pin(async move { 293 | let mut writer = writer_ref.lock().await; 294 | writer.write_all(&data).await?; 295 | Ok(()) 296 | })) 297 | } 298 | 299 | fn shutdown(&self) -> StreamResult> { 300 | let stream_ref = self.stream.clone(); 301 | 302 | Box::new(Box::pin(async move { 303 | let mut stream = stream_ref.lock().await; 304 | let _ = stream.shutdown().await; 305 | Ok(()) 306 | })) 307 | } 308 | } 309 | 310 | #[derive(Debug)] 311 | pub struct TlsTcpStreamWrapper { 312 | peer_addr: SocketAddr, 313 | stream: Arc>, 314 | reader: Arc>>>, 315 | writer: Arc>>>, 316 | buffer_size: usize, 317 | restored_payload: Arc>>>, 318 | } 319 | 320 | impl TlsTcpStreamWrapper { 321 | pub async fn from( 322 | tcp_stream: TcpStream, 323 | tls_acceptor: &TlsAcceptor, 324 | buffer_size: usize, 325 | ) -> std::io::Result { 326 | let peer_addr = 
tcp_stream.peer_addr()?; 327 | let std_tcp_stream = tcp_stream.into_std()?; 328 | 329 | // Stream for shutting down reader/writer later 330 | let stream = TcpStream::from_std(std_tcp_stream.try_clone()?)?; 331 | let async_reader = TcpStream::from_std(std_tcp_stream)?; 332 | 333 | let tls_async_stream = tls_acceptor.accept(async_reader).await?; 334 | let (reader, writer) = tokio::io::split(tls_async_stream); 335 | 336 | Ok(Self { 337 | peer_addr, 338 | stream: Arc::new(Mutex::new(stream)), 339 | reader: Arc::new(Mutex::new(reader)), 340 | writer: Arc::new(Mutex::new(writer)), 341 | buffer_size, 342 | restored_payload: Arc::new(Mutex::new(None)), 343 | }) 344 | } 345 | } 346 | 347 | impl AbstractStream for TlsTcpStreamWrapper { 348 | fn buffer_size(&self) -> StreamResult { 349 | let buffer_size = self.buffer_size.clone(); 350 | Box::new(Box::pin(async move { buffer_size })) 351 | } 352 | 353 | fn peer_addr(&self) -> StreamResult> { 354 | let peer_addr = self.peer_addr.clone(); 355 | 356 | Box::new(Box::pin(async move { Some(peer_addr) })) 357 | } 358 | 359 | fn restore_payload(&self, bytes: &[u8]) -> StreamResult> { 360 | let restored_payload_ref = self.restored_payload.clone(); 361 | 362 | let bytes = bytes.to_vec(); 363 | 364 | Box::new(Box::pin(async move { 365 | let mut restored_payload = restored_payload_ref.lock().await; 366 | *restored_payload = Some(bytes); 367 | Ok(()) 368 | })) 369 | } 370 | 371 | fn restored_len(&self) -> StreamResult { 372 | let restored_payload_ref = self.restored_payload.clone(); 373 | 374 | Box::new(Box::pin(async move { 375 | let restored_payload = restored_payload_ref.lock().await; 376 | 377 | if let Some(restored) = restored_payload.as_ref() { 378 | return restored.len(); 379 | } 380 | 381 | 0 382 | })) 383 | } 384 | 385 | fn read_chunk(&self) -> StreamResult>> { 386 | // If payload of some bytes is restored after reading the chunk, returns the same bytes 387 | // back to the reader again. 
388 | // Reading from stream wrapper is skipped because there may not be any bytes to read. 389 | let restored_payload_ref = self.restored_payload.clone(); 390 | let buffer_size = self.buffer_size.clone(); 391 | let reader = self.reader.clone(); 392 | 393 | Box::new(Box::pin(async move { 394 | let mut restored_payload = restored_payload_ref.lock().await; 395 | 396 | if let Some(payload) = restored_payload.as_ref() { 397 | let buffer = payload.to_owned(); 398 | *restored_payload = None; 399 | return Ok(buffer); 400 | } 401 | 402 | let mut buffer = vec![0u8; buffer_size]; 403 | let mut reader = reader.lock().await; 404 | 405 | return match reader.read(&mut buffer).await { 406 | Ok(read_size) => { 407 | if read_size == 0 { 408 | return Err(std::io::Error::new( 409 | ErrorKind::BrokenPipe, 410 | "Read size is 0. Probably connection broken.", 411 | )); 412 | } 413 | 414 | let chunk = &buffer[0..read_size]; 415 | Ok(chunk.to_vec()) 416 | } 417 | Err(error) => Err(std::io::Error::other(error)), 418 | }; 419 | })) 420 | } 421 | 422 | fn write_chunk(&self, data: &[u8]) -> StreamResult> { 423 | let writer_ref = self.writer.clone(); 424 | let data = data.to_vec(); 425 | 426 | Box::new(Box::pin(async move { 427 | let mut writer = writer_ref.lock().await; 428 | writer.write_all(&data).await?; 429 | Ok(()) 430 | })) 431 | } 432 | 433 | fn shutdown(&self) -> StreamResult> { 434 | let stream_ref = self.stream.clone(); 435 | 436 | Box::new(Box::pin(async move { 437 | let mut stream = stream_ref.lock().await; 438 | stream.shutdown().await?; 439 | Ok(()) 440 | })) 441 | } 442 | } 443 | 444 | pub struct TestStreamWrapper { 445 | test_data: Arc>>, 446 | buffer_size: usize, 447 | is_shutdown: Arc, 448 | restored_payload: Arc>>>, 449 | } 450 | 451 | impl TestStreamWrapper { 452 | pub fn new(test_data: Vec, buffer_size: usize) -> Self { 453 | Self { 454 | test_data: Arc::new(Mutex::new(test_data)), 455 | buffer_size, 456 | is_shutdown: Arc::new(AtomicBool::new(false)), 457 | 
restored_payload: Arc::new(Mutex::new(None)), 458 | } 459 | } 460 | } 461 | 462 | impl AbstractStream for TestStreamWrapper { 463 | fn buffer_size(&self) -> StreamResult { 464 | Box::new(Box::pin(async move { self.buffer_size.clone() })) 465 | } 466 | 467 | fn peer_addr(&self) -> StreamResult> { 468 | Box::new(Box::pin(async move { None })) 469 | } 470 | 471 | fn shutdown(&self) -> StreamResult> { 472 | self.is_shutdown.store(true, Ordering::Relaxed); 473 | Box::new(Box::pin(async move { Ok(()) })) 474 | } 475 | 476 | fn write_chunk(&self, _: &[u8]) -> StreamResult> { 477 | Box::new(Box::pin(async move { 478 | if self.is_shutdown.load(Ordering::Relaxed) { 479 | return Err(std::io::Error::other( 480 | "Test Stream is already shutdown. Failed to write chunk.", 481 | )); 482 | } 483 | Ok(()) 484 | })) 485 | } 486 | 487 | fn read_chunk(&self) -> StreamResult>> { 488 | Box::new(Box::pin(async move { 489 | let restored_payload_ref = self.restored_payload.clone(); 490 | let mut restored_payload = restored_payload_ref.lock().await; 491 | 492 | // Reads bytes from restored payload if any. 493 | if let Some(restored_bytes) = restored_payload.take() { 494 | if restored_bytes.len() > 0 { 495 | return Ok(restored_bytes); 496 | } 497 | }; 498 | 499 | if self.is_shutdown.load(Ordering::Relaxed) { 500 | return Err(std::io::Error::other( 501 | "Test Stream is already shutdown. 
Failed to read chunk.", 502 | )); 503 | } 504 | 505 | let test_data_ref = self.test_data.clone(); 506 | let mut test_data = test_data_ref.lock().await; 507 | 508 | // Reads bytes from test data 509 | let read_size = std::cmp::min(self.buffer_size, test_data.len()); 510 | if read_size == 0 { 511 | return Err(std::io::Error::other("No bytes left to read.")); 512 | } 513 | 514 | let removed_bytes = test_data.drain(0..read_size).collect(); 515 | Ok(removed_bytes) 516 | })) 517 | } 518 | 519 | fn restored_len(&self) -> StreamResult { 520 | Box::new(Box::pin(async move { 521 | let restored_payload_ref = self.restored_payload.clone(); 522 | let restored_payload = restored_payload_ref.lock().await; 523 | 524 | if let Some(restored_payload) = restored_payload.as_ref() { 525 | return restored_payload.len(); 526 | } 527 | 528 | 0 529 | })) 530 | } 531 | 532 | fn restore_payload(&self, bytes: &[u8]) -> StreamResult> { 533 | let bytes = bytes.to_vec(); 534 | 535 | Box::new(Box::pin(async move { 536 | let restored_payload_ref = self.restored_payload.clone(); 537 | let mut restored_payload = restored_payload_ref.lock().await; 538 | *restored_payload = Some(bytes); 539 | Ok(()) 540 | })) 541 | } 542 | } 543 | -------------------------------------------------------------------------------- /src/forms/fields/input_field.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::marker::PhantomData; 3 | use std::sync::atomic::AtomicBool; 4 | use std::sync::atomic::Ordering; 5 | use std::sync::Arc; 6 | 7 | use tokio::sync::Mutex; 8 | 9 | use crate::core::forms::{Files, FormData}; 10 | 11 | use crate::forms::fields::FieldResult; 12 | use crate::forms::AbstractFields; 13 | 14 | pub enum InputFieldError<'a> { 15 | MissingField(&'a String), 16 | /// (field_name, value, minimum_length) 17 | MinimumLengthRequired(&'a String, &'a String, &'a usize), 18 | /// (field_name, value, maximum_length) 19 | MaximumLengthExceed(&'a String, 
&'a String, &'a usize), 20 | } 21 | 22 | pub type PostValidator = Box Result>>; 23 | pub type ErrorHandler = Box) -> Vec>; 24 | 25 | pub trait ToOptionT { 26 | fn from_vec(value: &mut Vec) -> Option 27 | where 28 | Self: Sized; 29 | fn is_optional() -> bool; 30 | } 31 | 32 | impl ToOptionT for String { 33 | fn from_vec(values: &mut Vec) -> Option { 34 | if values.len() > 0 { 35 | return Some(values.remove(0)); 36 | } 37 | 38 | // Here None denotes values cannot be correctly converted to type T. 39 | None 40 | } 41 | 42 | fn is_optional() -> bool { 43 | false 44 | } 45 | } 46 | 47 | impl ToOptionT for Option { 48 | fn from_vec(values: &mut Vec) -> Option { 49 | if values.len() > 0 { 50 | let value = values.remove(0); 51 | return Some(Some(value)); 52 | } else { 53 | // Here outer Some denotes values are correctly converted to type T with value None. 54 | // Since fields are missing, default value is None. 55 | return Some(None); 56 | } 57 | } 58 | 59 | fn is_optional() -> bool { 60 | true 61 | } 62 | } 63 | 64 | impl ToOptionT for Vec { 65 | fn from_vec(values: &mut Vec) -> Option { 66 | // At least one value must be present to be a required field. 67 | if !values.is_empty() { 68 | // drain(..) moves every value out in submission order in O(n) and leaves `values` empty; 69 | // repeated remove(i) + insert(0, _) would shift elements on every step (O(n^2)). 70 | let owned_values = values.drain(..).collect(); 71 | 72 | 73 | 74 | return Some(owned_values); 75 | } 76 | 77 | // Here None denotes values cannot be correctly converted to type T. 78 | None 79 | } 80 | 81 | fn is_optional() -> bool { 82 | false 83 | } 84 | } 85 | 86 | impl ToOptionT for Option> { 87 | fn from_vec(values: &mut Vec) -> Option { 88 | // At least one value must be present to be a required field.
89 | if values.len() > 0 { 90 | let mut owned_values = vec![]; 91 | 92 | for i in (0..values.len()).rev() { 93 | owned_values.insert(0, values.remove(i)); 94 | } 95 | 96 | return Some(Some(owned_values)); 97 | } 98 | 99 | // Here no values are received but since it's optional field, 100 | // returns successfull conversion to type None. 101 | Some(None) 102 | } 103 | 104 | fn is_optional() -> bool { 105 | true 106 | } 107 | } 108 | 109 | type BoxResult = Box; 110 | 111 | pub struct InputField { 112 | field_name: String, 113 | /// Maximum allowed text size. 114 | max_length: Option>, 115 | /// Minimum length size for valid input field. 116 | min_length: Option>, 117 | /// Option enum holds the value of type T. 118 | result: Arc>>, 119 | /// Custom function callback for handling error. 120 | error_handler: Option>, 121 | /// Custom callback for post field validation. 122 | post_validator: Option>>, 123 | /// Default value if no form field value received. 124 | default_value: Option, 125 | /// True if validated successfully else false. 126 | validated: Arc, 127 | /// Dummy type for compile time and runtime check. 
128 | phantom: PhantomData, 129 | } 130 | 131 | impl InputField { 132 | pub fn new>(field_name: S) -> Self { 133 | let field_name = field_name.as_ref().to_string(); 134 | 135 | Self { 136 | field_name, 137 | max_length: None, 138 | min_length: None, 139 | result: Arc::new(Mutex::new(None)), 140 | error_handler: None, 141 | post_validator: None, 142 | default_value: None, 143 | validated: Arc::new(AtomicBool::from(false)), 144 | phantom: PhantomData, 145 | } 146 | } 147 | 148 | pub fn max_length(mut self, max_length: usize) -> Self { 149 | self.max_length = Some(Arc::new(max_length)); 150 | self 151 | } 152 | 153 | pub fn min_length(mut self, min_length: usize) -> Self { 154 | self.min_length = Some(Arc::new(min_length)); 155 | self 156 | } 157 | 158 | pub fn set_default>(mut self, value: S) -> Self { 159 | let value = value.as_ref().to_string(); 160 | self.default_value = Some(value); 161 | self 162 | } 163 | 164 | pub fn post_validate(mut self, call: fn(t: T) -> Result>) -> Self { 165 | self.post_validator = Some(Arc::new(Box::new(call))); 166 | self 167 | } 168 | 169 | pub fn handle_error_message( 170 | mut self, 171 | callback: fn(InputFieldError, Vec) -> Vec, 172 | ) -> Self { 173 | let callback = Arc::new(Box::new(callback)); 174 | self.error_handler = Some(callback); 175 | self 176 | } 177 | 178 | pub async fn value(self) -> T { 179 | if !self.validated.load(Ordering::Relaxed) { 180 | panic!("This field is not validated. Please call form.validate() method before accessing value."); 181 | } 182 | 183 | let mut result_ref = self.result.lock().await; 184 | let result = result_ref.take(); 185 | 186 | if let Some(result) = result { 187 | match result.downcast::() { 188 | Ok(t) => { 189 | return *t; 190 | } 191 | 192 | _ => {} 193 | }; 194 | } 195 | 196 | panic!("Unexpected error. 
Bug in input_field.rs file."); 197 | } 198 | } 199 | fn validate_input_length( 200 | field_name: &String, 201 | values: &Vec, 202 | error_handler: Option>, 203 | max_length: Option>, 204 | min_length: Option>, 205 | errors: &mut Vec, 206 | ) { 207 | let value; 208 | if let Some(value_ref) = values.get(0) { 209 | value = value_ref; 210 | } else { 211 | return; 212 | } 213 | // Only the first submitted value is length-checked; with no values there is nothing to check. 214 | if let Some(max_length) = max_length { 215 | // Checks maximum value length constraints 216 | if value.len() > *max_length { 217 | let default_max_length_exceed_messsage = 218 | format!("Character length exceeds maximum size of {}", *max_length); 219 | 220 | if let Some(error_handler) = error_handler.clone() { 221 | let max_length_exceed_error = 222 | InputFieldError::MaximumLengthExceed(&field_name, &value, &max_length); 223 | // Argument order follows the variant's documented (field_name, value, maximum_length) layout. 224 | let custom_errors = error_handler( 225 | max_length_exceed_error, 226 | vec![default_max_length_exceed_messsage], 227 | ); 228 | errors.extend(custom_errors); 229 | } else { 230 | errors.push(default_max_length_exceed_messsage); 231 | } 232 | } 233 | } 234 | 235 | if let Some(min_length) = min_length { 236 | // Checks minimum value length constraints 237 | if value.len() < *min_length { 238 | let default_max_length_exceed_messsage = 239 | format!("Text length is less then {}", *min_length); 240 | 241 | if let Some(error_handler) = error_handler.clone() { 242 | let max_length_exceed_error = 243 | InputFieldError::MinimumLengthRequired(&field_name, &value, &min_length); 244 | // Argument order follows the variant's documented (field_name, value, minimum_length) layout. 245 | let custom_errors = error_handler( 246 | max_length_exceed_error, 247 | vec![default_max_length_exceed_messsage], 248 | ); 249 | errors.extend(custom_errors); 250 | } else { 251 | errors.push(default_max_length_exceed_messsage); 252 | } 253 | } 254 | } 255 | } 256 | 257 | impl Clone for InputField { 258 | fn clone(&self) -> Self { 259 | Self { 260 | field_name: self.field_name.clone(), 261 | max_length: self.max_length.clone(), 262 | min_length: self.min_length.clone(), 263 |
error_handler: self.error_handler.clone(), 264 | post_validator: self.post_validator.clone(), 265 | result: self.result.clone(), 266 | default_value: self.default_value.clone(), 267 | validated: self.validated.clone(), 268 | phantom: self.phantom.clone(), 269 | } 270 | } 271 | } 272 | 273 | impl AbstractFields for InputField { 274 | fn field_name(&self) -> FieldResult { 275 | let field_name = self.field_name.clone(); 276 | Box::new(Box::pin(async move { field_name })) 277 | } 278 | 279 | fn validate( 280 | &mut self, 281 | form_data: &mut FormData, 282 | _: &mut Files, 283 | ) -> FieldResult>> { 284 | let field_name = self.field_name.clone(); 285 | 286 | let mut form_values; 287 | 288 | // Takes value from form field 289 | if let Some(values) = form_data.remove(&field_name) { 290 | form_values = Some(values); 291 | } else { 292 | form_values = None; 293 | } 294 | 295 | let max_length = self.max_length.clone(); 296 | let min_length = self.min_length.clone(); 297 | let default_value = self.default_value.take(); 298 | let validated = self.validated.clone(); 299 | let result = self.result.clone(); 300 | 301 | let error_handler = self.error_handler.clone(); 302 | let post_validator = self.post_validator.clone(); 303 | 304 | Box::new(Box::pin(async move { 305 | let mut errors: Vec = vec![]; 306 | 307 | let is_empty; 308 | if let Some(values) = form_values.as_mut() { 309 | validate_input_length( 310 | &field_name, 311 | &values, 312 | error_handler.clone(), 313 | max_length, 314 | min_length, 315 | &mut errors, 316 | ); 317 | 318 | is_empty = values.is_empty(); 319 | } else { 320 | is_empty = true; 321 | } 322 | 323 | // Handles field missing error. 
324 | let is_optional = T::is_optional(); 325 | 326 | if !is_optional && is_empty { 327 | // If default value is specified, set default value for value 328 | if let Some(default_value) = default_value { 329 | if is_empty { 330 | form_values = Some(vec![default_value]); 331 | } 332 | } else { 333 | let default_field_missing_error = "This field is missing.".to_string(); 334 | 335 | if let Some(error_handler) = error_handler { 336 | let field_missing_error = InputFieldError::MissingField(&field_name); 337 | let custom_errors = 338 | error_handler(field_missing_error, vec![default_field_missing_error]); 339 | errors.extend(custom_errors); 340 | } else { 341 | errors.push(default_field_missing_error); 342 | } 343 | } 344 | } 345 | 346 | if errors.len() > 0 { 347 | return Err(errors); 348 | } 349 | 350 | // All the validation conditions are satisfied. 351 | { 352 | let mut result_lock = result.lock().await; 353 | if let Some(values) = form_values.as_mut() { 354 | let value_t = T::from_vec(values); 355 | if let Some(mut t) = value_t { 356 | if let Some(post_validator) = post_validator { 357 | // Performs post validation callback. 358 | match post_validator(t) { 359 | Ok(post_validated_t) => { 360 | t = post_validated_t; 361 | *result_lock = Some(Box::new(t)); 362 | } 363 | Err(custom_errors) => { 364 | return Err(custom_errors); 365 | } 366 | } 367 | } else { 368 | *result_lock = Some(Box::new(t)); 369 | }; 370 | } 371 | } else { 372 | // Above conditions are satisfied however there are no values stored. 373 | // Probably Optional type without default value. 
374 | let value_t = T::from_vec(&mut vec![]); 375 | *result_lock = Some(Box::new(value_t.unwrap())); 376 | } 377 | } 378 | 379 | validated.store(true, Ordering::Relaxed); 380 | Ok(()) 381 | })) 382 | } 383 | 384 | fn wrap(&self) -> Box { 385 | Box::new(self.clone()) 386 | } 387 | } 388 | 389 | #[cfg(test)] 390 | pub mod test { 391 | use crate::core::forms::{Files, FormData}; 392 | use crate::forms::fields::AbstractFields; 393 | 394 | use super::InputField; 395 | 396 | #[tokio::test] 397 | async fn test_validate_default() { 398 | let mut form_data = FormData::new(); 399 | let mut files = Files::new(); 400 | 401 | let mut input_field: InputField = 402 | InputField::new("name").set_default("John").max_length(100); 403 | let result = input_field.validate(&mut form_data, &mut files).await; 404 | assert_eq!(true, result.is_ok()); 405 | 406 | let value = input_field.value().await; 407 | assert_eq!(value, "John"); 408 | } 409 | 410 | #[tokio::test] 411 | async fn test_validate_string() { 412 | let mut form_data = FormData::new(); 413 | form_data.insert("name".to_string(), vec!["John".to_string()]); 414 | 415 | let mut files = Files::new(); 416 | 417 | let mut input_field: InputField = InputField::new("name").max_length(100); 418 | let result = input_field.validate(&mut form_data, &mut files).await; 419 | assert_eq!(true, result.is_ok()); 420 | 421 | let value = input_field.value().await; 422 | assert_eq!(value, "John"); 423 | } 424 | 425 | #[tokio::test] 426 | async fn test_validate_optional() { 427 | let mut form_data = FormData::new(); 428 | let mut files = Files::new(); 429 | 430 | let mut input_field: InputField> = InputField::new("name").max_length(100); 431 | let result = input_field.validate(&mut form_data, &mut files).await; 432 | assert_eq!(true, result.is_ok()); 433 | 434 | let value = input_field.value().await; 435 | assert_eq!(value, None); 436 | 437 | // With values 438 | form_data.insert("name".to_string(), vec!["John".to_string()]); 439 | let mut 
input_field2: InputField> = InputField::new("name").max_length(100); 440 | let result = input_field2.validate(&mut form_data, &mut files).await; 441 | assert_eq!(true, result.is_ok()); 442 | assert_eq!(Some("John".to_string()), input_field2.value().await); 443 | } 444 | 445 | #[tokio::test] 446 | async fn test_validate_vec() { 447 | let mut form_data = FormData::new(); 448 | let mut files = Files::new(); 449 | 450 | let mut input_field: InputField> = InputField::new("name").max_length(100); 451 | let result = input_field.validate(&mut form_data, &mut files).await; 452 | assert_eq!(false, result.is_ok()); 453 | 454 | // With values 455 | let mut input_field2: InputField> = InputField::new("name").max_length(100); 456 | 457 | form_data.insert( 458 | "name".to_string(), 459 | vec![ 460 | "1".to_string(), 461 | "2".to_string(), 462 | "3".to_string(), 463 | "4".to_string(), 464 | ], 465 | ); 466 | 467 | let result = input_field2.validate(&mut form_data, &mut files).await; 468 | assert_eq!(true, result.is_ok()); 469 | assert_eq!(4, input_field2.value().await.len()); 470 | } 471 | 472 | #[tokio::test] 473 | async fn test_validate_vec_optional() { 474 | let mut form_data = FormData::new(); 475 | let mut files = Files::new(); 476 | 477 | let mut input_field: InputField>> = 478 | InputField::new("name").max_length(100); 479 | let result = input_field.validate(&mut form_data, &mut files).await; 480 | assert_eq!(true, result.is_ok()); 481 | assert_eq!(false, input_field.value().await.is_some()); 482 | 483 | // With values 484 | let mut input_field2: InputField>> = 485 | InputField::new("name").max_length(100); 486 | 487 | form_data.insert( 488 | "name".to_string(), 489 | vec![ 490 | "1".to_string(), 491 | "2".to_string(), 492 | "3".to_string(), 493 | "4".to_string(), 494 | ], 495 | ); 496 | 497 | let result = input_field2.validate(&mut form_data, &mut files).await; 498 | assert_eq!(true, result.is_ok()); 499 | 500 | let value = input_field2.value().await; 501 | 
assert_eq!(true, value.is_some()); 502 | assert_eq!(4, value.unwrap().len()); 503 | } 504 | 505 | #[tokio::test] 506 | async fn test_value_length() { 507 | // Validate long text 508 | let mut input_field: InputField = InputField::new("name").max_length(10); 509 | let mut form_data = FormData::new(); 510 | 511 | const LONG_PARAGRAPH: &str = r#" 512 | Lorem ipsum dolor sit amet, qui minim labore adipisicing minim sint cillum sint consectetur cupidatat. 513 | "#; 514 | form_data.insert("name".to_string(), vec![LONG_PARAGRAPH.to_string()]); 515 | 516 | let mut files = Files::new(); 517 | let result = input_field.validate(&mut form_data, &mut files).await; 518 | assert_eq!(false, result.is_ok()); 519 | 520 | // Validate long text 521 | let mut input_field2: InputField = InputField::new("name").min_length(100); 522 | let mut form_data = FormData::new(); 523 | 524 | const SHORT_PARAGRAPH: &str = r#" 525 | Lorem ipsum dolor sit amet. 526 | "#; 527 | form_data.insert("name".to_string(), vec![SHORT_PARAGRAPH.to_string()]); 528 | 529 | let mut files = Files::new(); 530 | let result = input_field2.validate(&mut form_data, &mut files).await; 531 | assert_eq!(false, result.is_ok()); 532 | } 533 | 534 | #[tokio::test] 535 | async fn test_empty_value_with_length() { 536 | let mut input_field: InputField = InputField::new("name").max_length(100); 537 | let mut form_data = FormData::new(); 538 | let mut files = Files::new(); 539 | let result = input_field.validate(&mut form_data, &mut files).await; 540 | assert_eq!(false, result.is_ok()); 541 | } 542 | 543 | #[tokio::test] 544 | async fn test_post_validate() { 545 | let mut input_field: InputField = InputField::new("name") 546 | .max_length(100) 547 | .post_validate(|value| { 548 | if !value.eq("John") { 549 | return Err(vec!["Value is not John".to_string()]); 550 | } 551 | 552 | Ok(value) 553 | }); 554 | let mut form_data = FormData::new(); 555 | form_data.insert("name".to_string(), vec!["Raphel".to_string()]); 556 | 557 | let mut 
files = Files::new(); 558 | let result = input_field.validate(&mut form_data, &mut files).await; 559 | assert_eq!(false, result.is_ok()); 560 | } 561 | } 562 | -------------------------------------------------------------------------------- /src/core/parser/multipart.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use async_tempfile::TempFile; 4 | use regex::Regex; 5 | use tokio::io::AsyncWriteExt; 6 | 7 | use crate::core::headers; 8 | use crate::core::headers::{HeaderValue, Headers}; 9 | 10 | use crate::core::stream::Stream; 11 | 12 | use crate::core::forms::{FileField, Files, FormConstraints, FormData, FormFieldError}; 13 | 14 | #[derive(Debug)] 15 | pub struct FormPart { 16 | pub name: Option, 17 | pub value: Option, 18 | pub filename: Option, 19 | pub content_type: Option, 20 | pub file: Option, 21 | } 22 | 23 | pub struct MultipartParser { 24 | stream: Arc, 25 | form_constraints: Arc, 26 | boundary: String, 27 | allow_next_header_read: bool, 28 | first_header_scanned: bool, 29 | } 30 | 31 | impl MultipartParser { 32 | pub fn from( 33 | stream: Arc, 34 | headers: &Headers, 35 | form_constraints: Arc, 36 | ) -> std::io::Result { 37 | let content_type; 38 | if let Some(value) = headers.value("content-type") { 39 | content_type = value; 40 | } else { 41 | return Err(std::io::Error::other("Content-Type header is missing.")); 42 | } 43 | 44 | let boundary = headers::multipart_boundary(&content_type)?; 45 | 46 | Ok(MultipartParser { 47 | stream, 48 | form_constraints, 49 | boundary, 50 | allow_next_header_read: true, 51 | first_header_scanned: false, 52 | }) 53 | } 54 | 55 | pub async fn parse( 56 | stream: Arc, 57 | form_constraints: Arc, 58 | headers: &Headers, 59 | ) -> Result<(FormData, Files), FormFieldError> { 60 | let mut parser = match MultipartParser::from(stream, headers, form_constraints) { 61 | Ok(parser) => parser, 62 | Err(error) => { 63 | return Err(FormFieldError::Others(None, 
error.to_string(), true)); 64 | } 65 | }; 66 | 67 | let mut form_data = FormData::new(); 68 | let mut files = Files::new(); 69 | 70 | loop { 71 | let mut form_part = parser.next_form_header().await?; 72 | let parsing_completed = parser.next_form_value(&mut form_part).await?; 73 | 74 | let field_name; 75 | if let Some(value) = form_part.name { 76 | field_name = value; 77 | } else { 78 | return Err(FormFieldError::Others( 79 | None, 80 | "Field name is missing.".to_owned(), 81 | true, 82 | )); 83 | } 84 | 85 | if let Some(filename) = form_part.filename { 86 | let named_temp_file; 87 | if let Some(file) = form_part.file { 88 | named_temp_file = file; 89 | } else { 90 | return Err(FormFieldError::Others( 91 | Some(field_name.clone()), 92 | "Parsing error: file is missing.".to_owned(), 93 | true, 94 | )); 95 | } 96 | 97 | let temp_file = FileField::from(filename, named_temp_file); 98 | if let Some(files) = files.get_mut(&field_name) { 99 | files.push(temp_file); 100 | } else { 101 | files.insert(field_name, vec![temp_file]); 102 | } 103 | } else { 104 | if let Some(field_value) = form_part.value { 105 | if let Some(values) = form_data.get_mut(&field_name) { 106 | values.push(field_value); 107 | } else { 108 | form_data.insert(field_name, vec![field_value]); 109 | } 110 | } 111 | } 112 | 113 | if parsing_completed { 114 | return Ok((form_data, files)); 115 | } 116 | } 117 | } 118 | 119 | pub async fn next_form_header(&mut self) -> Result { 120 | if !self.allow_next_header_read { 121 | return Err(FormFieldError::Others( 122 | None, 123 | "Form part body not read.".to_string(), 124 | true, 125 | )); 126 | } 127 | 128 | let stream = self.stream.clone(); 129 | let max_header_size = self 130 | .form_constraints 131 | .max_header_size(stream.buffer_size().await); 132 | let scan_boundary = format!("--{}\r\n", &self.boundary); 133 | let scan_boundary_bytes = scan_boundary.as_bytes(); 134 | 135 | let mut buffer = vec![]; 136 | let mut bytes_read = 0; 137 | 138 | // Removes 
starting header for easier pattern matching 139 | if !self.first_header_scanned { 140 | // Fetches minimum bytes equal to scan boundary length 141 | loop { 142 | if buffer.len() >= scan_boundary.len() { 143 | break; 144 | } 145 | 146 | let chunk = match stream.read_chunk().await { 147 | Ok(bytes) => bytes, 148 | Err(error) => { 149 | return Err(FormFieldError::Others(None, error.to_string(), true)); 150 | } 151 | }; 152 | bytes_read += chunk.len(); 153 | buffer.extend(chunk); 154 | } 155 | 156 | if !buffer.starts_with(scan_boundary_bytes) { 157 | return Err(FormFieldError::Others( 158 | None, 159 | format!("Boundary does not start with {}", scan_boundary), 160 | true, 161 | )); 162 | } 163 | 164 | // Removes scan boundary bytes from buffer 165 | // Contains only form part header 166 | buffer.drain(0..scan_boundary.len()); 167 | self.first_header_scanned = true; 168 | } 169 | 170 | const FORM_PART_HEADER_TERMINATOR: &[u8; 4] = b"\r\n\r\n"; 171 | 172 | loop { 173 | if bytes_read > max_header_size { 174 | return Err(FormFieldError::MaxHeaderSizeExceed); 175 | } 176 | 177 | let scan_result = buffer 178 | .windows(FORM_PART_HEADER_TERMINATOR.len()) 179 | .position(|window| window == FORM_PART_HEADER_TERMINATOR); 180 | 181 | if let Some(position) = scan_result { 182 | let form_part_header_bytes = &buffer[..position]; 183 | let restore_bytes = &buffer[position + FORM_PART_HEADER_TERMINATOR.len()..]; 184 | let _ = stream.restore_payload(restore_bytes.as_ref()).await; 185 | 186 | // Deny next time calling this method because form part body also must be read. 187 | self.allow_next_header_read = false; 188 | return Ok(parse_form_part_header(form_part_header_bytes)?); 189 | } else { 190 | // Still form part not found. Collect more bytes. 
191 | let chunk = match stream.read_chunk().await { 192 | Ok(bytes) => bytes, 193 | Err(error) => { 194 | return Err(FormFieldError::Others(None, error.to_string(), true)); 195 | } 196 | }; 197 | bytes_read += chunk.len(); 198 | buffer.extend(chunk); 199 | } 200 | } 201 | } 202 | 203 | pub async fn next_form_value( 204 | &mut self, 205 | form_part: &mut FormPart, 206 | ) -> Result { 207 | if self.allow_next_header_read { 208 | return Err(FormFieldError::Others( 209 | None, 210 | "Form part header is not read.".to_owned(), 211 | true, 212 | )); 213 | } 214 | 215 | if form_part.filename.is_some() { 216 | Ok(self.parse_file(form_part).await?) 217 | } else { 218 | Ok(self.parse_value(form_part).await?) 219 | } 220 | } 221 | 222 | pub async fn parse_file(&mut self, form_part: &mut FormPart) -> Result { 223 | let form_constraints = self.form_constraints.clone(); 224 | let field_name; 225 | if let Some(value) = &form_part.name { 226 | field_name = value.to_owned(); 227 | } else { 228 | return Err(FormFieldError::Others( 229 | None, 230 | "Field name is missing".to_owned(), 231 | false, 232 | )); 233 | } 234 | 235 | // Form constraints 236 | let max_file_size = 237 | form_constraints.max_size_for_file(&field_name, self.stream.buffer_size().await); 238 | let mut bytes_read = 0; 239 | 240 | let value_terminator = format!("\r\n--{}", self.boundary); 241 | let value_terminator_bytes = value_terminator.as_bytes(); 242 | 243 | let mut temp_file = match TempFile::new().await { 244 | Ok(file) => match file.open_rw().await { 245 | Ok(result) => result, 246 | Err(error) => { 247 | return Err(FormFieldError::Others(None, error.to_string(), true)); 248 | } 249 | }, 250 | Err(error) => { 251 | return Err(FormFieldError::Others(None, error.to_string(), true)); 252 | } 253 | }; 254 | let mut scan_buffer = vec![]; 255 | const FORM_PART_END: &[u8; 4] = b"--\r\n"; 256 | const CRLF_BREAK: &[u8; 2] = b"\r\n"; 257 | 258 | loop { 259 | if bytes_read > max_file_size { 260 | return 
Err(FormFieldError::MaxFileSizeExceed(field_name.clone())); 261 | } 262 | 263 | let scan_result = scan_buffer 264 | .windows(value_terminator_bytes.len()) 265 | .position(|window| window == value_terminator_bytes); 266 | 267 | if let Some(matched_position) = scan_result { 268 | // File scan reached end 269 | // Extra offset to check whether file ends or not 270 | // If extra terminator byte offset is not present, it does not matter whether field 271 | // end is found or not. Can be scanned again. 272 | 273 | if scan_buffer.len() 274 | >= matched_position + value_terminator_bytes.len() + FORM_PART_END.len() 275 | { 276 | let to_copy_position = matched_position; 277 | let to_copy = &scan_buffer[..to_copy_position]; 278 | 279 | match temp_file.write_all(to_copy).await { 280 | Ok(()) => {} 281 | Err(error) => { 282 | return Err(FormFieldError::Others( 283 | Some(field_name.to_string()), 284 | format!("Failed to write file. Error: {}", error), 285 | true, 286 | )); 287 | } 288 | } 289 | 290 | let _ = temp_file.flush().await; 291 | 292 | scan_buffer = 293 | (&scan_buffer[to_copy_position + value_terminator_bytes.len()..]).to_vec(); 294 | return if &scan_buffer[..FORM_PART_END.len()] == FORM_PART_END { 295 | // Request body completed 296 | form_part.file = Some(temp_file); 297 | self.allow_next_header_read = true; 298 | Ok(true) 299 | } else { 300 | // Form part completed but body is not ended yet 301 | // Skips line break \r\n 302 | scan_buffer.drain(..CRLF_BREAK.len()); 303 | let _ = self.stream.restore_payload(&scan_buffer.as_ref()).await; 304 | form_part.file = Some(temp_file); 305 | self.allow_next_header_read = true; 306 | Ok(false) 307 | }; 308 | } 309 | } 310 | 311 | // Copy data 312 | if scan_buffer.len() > value_terminator_bytes.len() { 313 | // This much amount of bytes can be copied safely from the file buffer 314 | let to_copy_position = scan_buffer.len() - value_terminator_bytes.len(); 315 | 316 | match 
temp_file.write_all(&scan_buffer[..to_copy_position]).await { 317 | Ok(()) => {} 318 | Err(error) => { 319 | return Err(FormFieldError::Others( 320 | Some(field_name.to_string()), 321 | format!("Failed to write file. Error: {}", error), 322 | true, 323 | )); 324 | } 325 | } 326 | 327 | scan_buffer.drain(..to_copy_position); 328 | } 329 | 330 | // File ending has not been reached 331 | let chunk = match self.stream.read_chunk().await { 332 | Ok(bytes) => bytes, 333 | Err(error) => { 334 | return Err(FormFieldError::Others(None, error.to_string(), true)); 335 | } 336 | }; 337 | bytes_read += chunk.len(); 338 | scan_buffer.extend(chunk); 339 | } 340 | } 341 | 342 | pub async fn parse_value(&mut self, form_part: &mut FormPart) -> Result { 343 | let field_name; 344 | if let Some(value) = &form_part.name { 345 | field_name = value.to_owned(); 346 | } else { 347 | return Err(FormFieldError::Others( 348 | None, 349 | "Field name is missing.".to_owned(), 350 | false, 351 | )); 352 | } 353 | 354 | let max_value_size = self 355 | .form_constraints 356 | .max_size_for_field(&field_name, self.stream.buffer_size().await); 357 | let scan_boundary = format!("\r\n--{}", self.boundary); 358 | let scan_boundary_bytes = scan_boundary.as_bytes(); 359 | 360 | let mut buffer = vec![]; 361 | 362 | const FORM_PART_END: &[u8; 4] = b"--\r\n"; 363 | const CRLF_BREAK: &[u8; 2] = b"\r\n"; 364 | 365 | let mut bytes_read = 0; 366 | 367 | loop { 368 | if bytes_read > max_value_size { 369 | return Err(FormFieldError::MaxValueSizeExceed(field_name)); 370 | } 371 | let scan_result = buffer 372 | .windows(scan_boundary_bytes.len()) 373 | .position(|window| window == scan_boundary_bytes); 374 | 375 | if let Some(position) = scan_result { 376 | if buffer.len() >= position + scan_boundary_bytes.len() + FORM_PART_END.len() { 377 | let to_copy = &buffer[..position]; 378 | let mut to_copy_range = to_copy.len(); 379 | 380 | // Some clients sends single CRLF and some double CRLF line breaks 381 | if 
to_copy.len() > 1 382 | && &to_copy[..to_copy.len() - CRLF_BREAK.len()] == CRLF_BREAK 383 | { 384 | to_copy_range -= 1; 385 | } 386 | 387 | let value = String::from_utf8_lossy(&to_copy[..to_copy_range]).to_string(); 388 | 389 | // Removes copied bytes from the buffer 390 | buffer.drain(..position + scan_boundary_bytes.len()); 391 | form_part.value = Some(value); 392 | 393 | return if &buffer[..FORM_PART_END.len()] == FORM_PART_END { 394 | self.allow_next_header_read = true; 395 | Ok(true) 396 | } else { 397 | // Form part completed but body is not ended yet 398 | // Skips line break \r\n 399 | buffer.drain(..CRLF_BREAK.len()); 400 | let _ = self.stream.restore_payload(buffer.as_ref()).await; 401 | self.allow_next_header_read = true; 402 | Ok(false) 403 | }; 404 | } 405 | } 406 | 407 | let chunk = match self.stream.read_chunk().await { 408 | Ok(bytes) => bytes, 409 | Err(error) => { 410 | return Err(FormFieldError::Others(None, error.to_string(), true)); 411 | } 412 | }; 413 | bytes_read += chunk.len(); 414 | buffer.extend(chunk); 415 | } 416 | } 417 | } 418 | 419 | pub fn parse_form_part_header(header_bytes: &[u8]) -> Result { 420 | let mut last_scanned_position = 0; 421 | const HEADER_LINE_TERMINATOR: &[u8; 2] = b"\r\n"; 422 | 423 | let mut header_bytes = header_bytes.to_vec(); 424 | 425 | // Makes sure scan window reach upto last header line 426 | if !header_bytes.ends_with(b"\r\n") { 427 | header_bytes.extend(b"\r\n"); 428 | } 429 | 430 | let mut form_part = FormPart { 431 | name: None, 432 | filename: None, 433 | content_type: None, 434 | file: None, 435 | value: None, 436 | }; 437 | 438 | loop { 439 | let to_scan = &header_bytes[last_scanned_position..]; 440 | let scan_result = to_scan 441 | .windows(HEADER_LINE_TERMINATOR.len()) 442 | .position(|window| window == HEADER_LINE_TERMINATOR); 443 | 444 | if let Some(relative_position) = scan_result { 445 | // One header found 446 | let header_line = 447 | &header_bytes[last_scanned_position..last_scanned_position 
+ relative_position]; 448 | match parse_form_part_header_line(header_line, &mut form_part) { 449 | Ok(()) => {} 450 | Err(error) => { 451 | return Err(FormFieldError::Others(None, error.to_string(), true)); 452 | } 453 | }; 454 | last_scanned_position += relative_position + HEADER_LINE_TERMINATOR.len(); 455 | } else { 456 | return Ok(form_part); 457 | } 458 | } 459 | } 460 | 461 | fn parse_form_part_header_line( 462 | header_line: &[u8], 463 | form_part: &mut FormPart, 464 | ) -> std::io::Result<()> { 465 | let header_line = String::from_utf8_lossy(header_line); 466 | let parts: Vec<&str> = header_line.splitn(2, ":").collect(); 467 | 468 | if parts.len() != 2 { 469 | return Ok(()); 470 | } 471 | 472 | let header_name; 473 | if let Some(name) = parts.get(0) { 474 | header_name = name.trim(); 475 | } else { 476 | return Err(std::io::Error::other("Header name is missing.")); 477 | } 478 | 479 | let header_value; 480 | if let Some(value) = parts.get(1) { 481 | header_value = *value; 482 | } else { 483 | return Err(std::io::Error::other("Header value is missing.")); 484 | } 485 | 486 | if header_name.to_lowercase() == "content-disposition" { 487 | parse_content_disposition_value(header_value, form_part)?; 488 | } else if header_name.to_lowercase() == "content-type" { 489 | form_part.content_type = Some(header_value.trim().to_string()); 490 | } 491 | Ok(()) 492 | } 493 | 494 | pub fn parse_content_disposition_value( 495 | value: &str, 496 | form_part: &mut FormPart, 497 | ) -> std::io::Result<()> { 498 | let value = value.trim(); 499 | 500 | if !value.starts_with("form-data;") { 501 | // Not a valid Content-Deposition value for form part header 502 | return Err(std::io::Error::other( 503 | "Not a valid Content-Deposition value for form part header", 504 | )); 505 | } 506 | 507 | let remaining = value.strip_prefix("form-data;").unwrap().trim(); 508 | let pattern = Regex::new(r#"(?\w+)="(?[^"]*)""#).unwrap(); 509 | 510 | // Goes through all attributes and values 511 | for 
captured in pattern.captures_iter(remaining) { 512 | let attribute = &captured["attribute"]; 513 | let value = &captured["value"]; 514 | 515 | if attribute == "name" { 516 | form_part.name = Some(value.to_string()); 517 | } else if attribute == "filename" { 518 | form_part.filename = Some(value.to_string()); 519 | } 520 | } 521 | 522 | if form_part.name.is_none() { 523 | return Err(std::io::Error::other( 524 | "Field name is missing in form part header.", 525 | )); 526 | } 527 | 528 | Ok(()) 529 | } 530 | 531 | #[cfg(test)] 532 | pub mod tests { 533 | use std::{collections::HashMap, sync::Arc}; 534 | 535 | use crate::core::forms::{FileFieldShortcut, FormConstraints}; 536 | use crate::core::headers::{HeaderValue, Headers}; 537 | use crate::core::shortcuts::SingleText; 538 | use crate::core::stream::{AbstractStream, TestStreamWrapper}; 539 | 540 | use super::MultipartParser; 541 | 542 | #[tokio::test] 543 | async fn test_multipart_parser() { 544 | let mut headers = Headers::new(); 545 | headers.set("Content-Type", "multipart/form-data; boundary=boundary123"); 546 | 547 | let test_data = "--boundary123\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\nJohn\r\n--boundary123\r\nContent-Disposition: form-data; name=\"location\"\r\n\r\nktm\r\n--boundary123\r\nContent-Disposition: form-data; name=\"file\"; filename=\"example.txt\"\r\nContent-Type: text/plain\r\n\r\nHello World\r\n--boundary123--\r\n".as_bytes().to_vec(); 548 | headers.set("Content-Length", test_data.len().to_string()); 549 | 550 | let stream: Box = Box::new(TestStreamWrapper::new(test_data, 1024)); 551 | 552 | let form_constraints = Arc::new(FormConstraints::new( 553 | 500 * 1024 * 1024, 554 | 2 * 1024 * 1024, 555 | 500 * 1024 * 1024, 556 | 2 * 1024 * 1024, 557 | HashMap::new(), 558 | )); 559 | 560 | let parser = MultipartParser::parse(Arc::new(stream), form_constraints, &headers).await; 561 | assert_eq!(true, parser.is_ok()); 562 | 563 | let (form_data, files) = parser.unwrap(); 564 | 
assert_eq!(Some(&"John".to_string()), form_data.value("name")); 565 | assert_eq!(Some(&"ktm".to_string()), form_data.value("location")); 566 | 567 | let file_field = files.value("file"); 568 | assert_eq!(true, file_field.is_some()); 569 | 570 | let file = file_field.unwrap(); 571 | let file_path = &file.temp_path; 572 | assert_eq!("example.txt".to_string(), file.name); 573 | 574 | let file_content = tokio::fs::read_to_string(&file_path).await.unwrap(); 575 | assert_eq!("Hello World".to_string(), file_content); 576 | } 577 | } 578 | --------------------------------------------------------------------------------