├── files ├── example.txt ├── example.yaml └── example.json ├── rust-toolchain ├── .github └── pull_request_template.md ├── src ├── transport_layer │ ├── transport_layer_type.rs │ ├── mod.rs │ ├── transport_layer.rs │ ├── transport_layer_builder.rs │ ├── into_transport_layer.rs │ └── into_transport_layer │ │ ├── axum_service.rs │ │ ├── shuttle_axum.rs │ │ ├── router.rs │ │ ├── into_make_service_with_connect_info.rs │ │ ├── serve.rs │ │ ├── with_graceful_shutdown.rs │ │ └── into_make_service.rs ├── internals │ ├── websockets │ │ ├── mod.rs │ │ ├── test_response_websocket.rs │ │ └── ws_key_generator.rs │ ├── transport_layer │ │ ├── mod.rs │ │ ├── http_transport_layer.rs │ │ └── mock_transport_layer.rs │ ├── with_this_mut.rs │ ├── mod.rs │ ├── expected_state.rs │ ├── status_code_formatter.rs │ ├── format_status_code_range.rs │ ├── try_into_range_bounds.rs │ ├── request_path_formatter.rs │ ├── starting_tcp_setup.rs │ ├── query_params_store.rs │ └── debug_response_body.rs ├── util │ ├── new_random_port.rs │ ├── mod.rs │ ├── new_random_socket_addr.rs │ ├── serve_handle.rs │ ├── new_random_tcp_listener.rs │ ├── spawn_serve.rs │ └── new_random_tokio_tcp_listener.rs ├── expect_json.rs ├── test_request │ └── test_request_config.rs ├── testing │ └── mod.rs ├── transport.rs ├── multipart │ ├── mod.rs │ ├── multipart_form.rs │ └── part.rs ├── test_server │ └── server_shared_state.rs ├── test_server_config.rs ├── test_server_builder.rs ├── lib.rs └── test_web_socket.rs ├── tests └── test-expect-json-integration.rs ├── .gitignore ├── examples ├── example-websocket-chat │ ├── README.md │ └── main.rs ├── example-websocket-ping-pong │ ├── README.md │ └── main.rs ├── example-todo │ ├── README.md │ └── main.rs └── example-shuttle │ ├── README.md │ └── main.rs ├── LICENSE ├── makefile ├── Cargo.toml └── README.md /files/example.txt: -------------------------------------------------------------------------------- 1 | hello! 
-------------------------------------------------------------------------------- /rust-toolchain: -------------------------------------------------------------------------------- 1 | stable -------------------------------------------------------------------------------- /files/example.yaml: -------------------------------------------------------------------------------- 1 | name: Joe 2 | age: 20 -------------------------------------------------------------------------------- /files/example.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Joe", 3 | "age": 20 4 | } 5 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # Changes 2 | 3 | * 4 | 5 | # Comments 6 | 7 | Any other business. 8 | 9 | -------------------------------------------------------------------------------- /src/transport_layer/transport_layer_type.rs: -------------------------------------------------------------------------------- 1 | /// Which kind of transport a `TransportLayer` uses, as reported by `TransportLayer::transport_layer_type`. 2 | #[derive(Clone, Copy, PartialEq, Eq, Debug)] 3 | pub enum TransportLayerType { 4 | /// Requests are sent over HTTP to a running server. 5 | Http, 6 | /// Requests are sent through a mock transport. 7 | Mock, 8 | } 9 | -------------------------------------------------------------------------------- /src/internals/websockets/mod.rs: -------------------------------------------------------------------------------- 1 | mod test_response_websocket; 2 | pub use self::test_response_websocket::*; 3 | 4 | mod ws_key_generator; 5 | pub use self::ws_key_generator::*; 6 | -------------------------------------------------------------------------------- /src/internals/transport_layer/mod.rs: -------------------------------------------------------------------------------- 1 | mod http_transport_layer; 2 | pub use self::http_transport_layer::*; 3 | 4 | mod mock_transport_layer; 5 | pub use self::mock_transport_layer::*; 6 |
-------------------------------------------------------------------------------- /src/internals/websockets/test_response_websocket.rs: -------------------------------------------------------------------------------- 1 | use hyper::upgrade::OnUpgrade; 2 | 3 | use crate::transport_layer::TransportLayerType; 4 | 5 | /// The WebSocket upgrade state carried by a test response. 6 | #[derive(Debug, Clone)] 7 | pub struct TestResponseWebSocket { 8 | /// The pending upgrade, when one is available for this response. 9 | pub maybe_on_upgrade: Option<OnUpgrade>, 10 | /// Which kind of transport the response came through. 11 | pub transport_type: TransportLayerType, 12 | } 13 | -------------------------------------------------------------------------------- /src/internals/websockets/ws_key_generator.rs: -------------------------------------------------------------------------------- 1 | use base64::Engine; 2 | use base64::engine::general_purpose::STANDARD; 3 | use uuid::Uuid; 4 | 5 | /// Generates a random WebSocket handshake key: the 16 bytes of a random v4 UUID, base64 encoded for use over HTTP. 6 | pub fn generate_ws_key() -> String { 7 | STANDARD.encode(Uuid::new_v4().as_bytes()) 8 | } 9 | -------------------------------------------------------------------------------- /src/util/new_random_port.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use anyhow::anyhow; 3 | use reserve_port::ReservedPort; 4 | 5 | /// Returns a randomly selected port that is not in use.
6 | pub fn new_random_port() -> Result<u16> { 7 | ReservedPort::random_permanently_reserved().map_err(|_| anyhow!("No free port was found")) 8 | } 9 | -------------------------------------------------------------------------------- /src/transport_layer/mod.rs: -------------------------------------------------------------------------------- 1 | mod into_transport_layer; 2 | pub use self::into_transport_layer::*; 3 | 4 | mod transport_layer_builder; 5 | pub use self::transport_layer_builder::*; 6 | 7 | mod transport_layer_type; 8 | pub use self::transport_layer_type::*; 9 | 10 | mod transport_layer; 11 | pub use self::transport_layer::*; 12 | -------------------------------------------------------------------------------- /tests/test-expect-json-integration.rs: -------------------------------------------------------------------------------- 1 | use axum_test::expect_json::expect_core::ExpectOp; 2 | use axum_test::expect_json::expect_core::expect_op; 3 | 4 | // If it compiles, it works! 5 | #[test] 6 | fn test_expect_op_for_axum_test_integration_compiles() { 7 | #[expect_op] 8 | #[derive(Debug, Clone)] 9 | pub struct Testing; 10 | 11 | impl ExpectOp for Testing {} 12 | } 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Cargo 2 | /target/ 3 | /examples/*/target 4 | Cargo.lock 5 | 6 | # Backup files generated by rustfmt 7 | **/*.rs.bk 8 | 9 | # MSVC Windows builds of rustc generate these, which store debugging information 10 | *.pdb 11 | 12 | # Local todo file 13 | /todo.txt 14 | /todo.md 15 | 16 | # MacOS 17 | .DS_Store 18 | 19 | *.swp 20 | 21 | # IDE folders 22 | .idea 23 | .vscode 24 | 25 | # Environment files (just in case) 26 | .env 27 | .env* 28 | -------------------------------------------------------------------------------- /src/util/mod.rs: 
-------------------------------------------------------------------------------- 1 | mod new_random_port; 2 | pub use self::new_random_port::*; 3 | 4 | mod new_random_socket_addr; 5 | pub use self::new_random_socket_addr::*; 6 | 7 | mod new_random_tcp_listener; 8 | pub use self::new_random_tcp_listener::*; 9 | 10 | mod new_random_tokio_tcp_listener; 11 | pub use self::new_random_tokio_tcp_listener::*; 12 | 13 | mod spawn_serve; 14 | pub use self::spawn_serve::*; 15 | 16 | mod serve_handle; 17 | pub use self::serve_handle::*; 18 | -------------------------------------------------------------------------------- /examples/example-websocket-chat/README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | Example WebSockets Chat
4 |

5 | 6 |

7 | a simple chat application with tests 8 |

9 | 10 |
11 |
12 | 13 | This is a very simple application using WebSockets. It aims to show ... 14 | 15 | * How to write a very basic chat application, 16 | * and include tests which send and receive data. 17 | 18 | It's primarily to provide some code samples using axum-test. 19 | -------------------------------------------------------------------------------- /examples/example-websocket-ping-pong/README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | Example WebSockets Ping Pong
4 |

5 | 6 |

7 | an example websocket application with tests 8 |

9 | 10 |
11 |
12 | 13 | This is a very simple application using WebSockets. It aims to show ... 14 | 15 | * How to write a basic test that starts a WebSocket connection. 16 | * A basic ping pong test, where data is pushed up and down. 17 | 18 | It's primarily to provide some code samples using axum-test. 19 | -------------------------------------------------------------------------------- /examples/example-todo/README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | Example REST Todo
4 |

5 | 6 |

7 | an example application with tests 8 |

9 | 10 |
11 |
12 | 13 | This is a very simple todo application. It aims to show ... 14 | 15 | * How to write some basic tests against endpoints. 16 | * How to write tests that expect success, and tests that expect failure. 17 | * How to take cookies into account (like logging in). 18 | 19 | It's primarily to provide some code samples using axum-test. 20 | -------------------------------------------------------------------------------- /examples/example-shuttle/README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | Example Shuttle REST Todo
4 |

5 | 6 |

7 | an example application with tests 8 |

9 | 10 |
11 |
12 | 13 | This is a very simple todo application. It aims to show ... 14 | 15 | * How to write some basic tests against endpoints. 16 | * How to write tests that expect success, and tests that expect failure. 17 | * How to take cookies into account (like logging in). 18 | 19 | It's primarily to provide some code samples using axum-test. 20 | -------------------------------------------------------------------------------- /src/expect_json.rs: -------------------------------------------------------------------------------- 1 | pub use ::expect_json::expect::*; 2 | 3 | /// For implementing your own expectations. 4 | pub mod expect_core { 5 | pub use ::expect_json::expect_core::*; 6 | 7 | /// This macro is for defining your own custom [`ExpectOp`] checks. 8 | #[doc(inline)] 9 | pub use ::expect_json::expect_core::expect_op_for_axum_test as expect_op; 10 | 11 | pub use ::expect_json::ExpectJsonError; 12 | pub use ::expect_json::ExpectJsonResult; 13 | pub use ::expect_json::JsonType; 14 | } 15 | 16 | #[doc(hidden)] 17 | pub use ::expect_json::__private; 18 | -------------------------------------------------------------------------------- /src/util/new_random_socket_addr.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use std::net::IpAddr; 3 | use std::net::Ipv4Addr; 4 | use std::net::SocketAddr; 5 | 6 | use crate::util::new_random_port; 7 | 8 | pub(crate) const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); 9 | 10 | /// Generates a `SocketAddr` on the IP 127.0.0.1, using a random port.
11 | pub fn new_random_socket_addr() -> Result<SocketAddr> { 12 | let ip_address = DEFAULT_IP_ADDRESS; 13 | let port = new_random_port()?; 14 | let addr = SocketAddr::new(ip_address, port); 15 | 16 | Ok(addr) 17 | } 18 | -------------------------------------------------------------------------------- /src/internals/with_this_mut.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use anyhow::anyhow; 3 | use std::sync::Arc; 4 | use std::sync::Mutex; 5 | 6 | /// Locks `this`, then runs `some_action` with mutable access to the locked value. 7 | /// 8 | /// `name` is only used in the error message, which is returned if the mutex is poisoned. 9 | pub fn with_this_mut<T, F, R>(this: &Arc<Mutex<T>>, name: &str, some_action: F) -> Result<R> 10 | where 11 | F: FnOnce(&mut T) -> R, 12 | { 13 | let mut this_locked = this.lock().map_err(|err| { 14 | anyhow!( 15 | "Failed to lock InternalTestServer for `{}`, {:?}", 16 | name, 17 | err, 18 | ) 19 | })?; 20 | 21 | let result = some_action(&mut this_locked); 22 | 23 | Ok(result) 24 | } 25 | -------------------------------------------------------------------------------- /src/test_request/test_request_config.rs: -------------------------------------------------------------------------------- 1 | use cookie::CookieJar; 2 | use http::HeaderName; 3 | use http::HeaderValue; 4 | use http::Method; 5 | use url::Url; 6 | 7 | use crate::internals::ExpectedState; 8 | use crate::internals::QueryParamsStore; 9 | 10 | #[derive(Debug, Clone)] 11 | pub struct TestRequestConfig { 12 | pub is_saving_cookies: bool, 13 | pub expected_state: ExpectedState, 14 | pub content_type: Option<String>, 15 | pub full_request_url: Url, 16 | pub method: Method, 17 | 18 | pub cookies: CookieJar, 19 | pub query_params: QueryParamsStore, 20 | pub headers: Vec<(HeaderName, HeaderValue)>, 21 | } 22 | -------------------------------------------------------------------------------- /src/util/serve_handle.rs: -------------------------------------------------------------------------------- 1 | use tokio::task::JoinHandle; 2 | 3 | /// A handle to a running Axum service.
4 | /// 5 | /// When the handle is dropped, it will attempt to terminate the service. 6 | #[must_use = "the service is aborted as soon as the ServeHandle is dropped"] 7 | #[derive(Debug)] 8 | pub struct ServeHandle { 9 | server_handle: JoinHandle<()>, 10 | } 11 | 12 | impl ServeHandle { 13 | pub(crate) fn new(server_handle: JoinHandle<()>) -> Self { 14 | Self { server_handle } 15 | } 16 | 17 | /// Returns `true` if the spawned service task has finished running. 18 | pub fn is_finished(&self) -> bool { 19 | self.server_handle.is_finished() 20 | } 21 | } 22 | 23 | impl Drop for ServeHandle { 24 | fn drop(&mut self) { 25 | // Aborting the task is what terminates the service when the handle goes away. 26 | self.server_handle.abort() 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/internals/mod.rs: -------------------------------------------------------------------------------- 1 | mod transport_layer; 2 | pub use self::transport_layer::*; 3 | 4 | #[cfg(feature = "ws")] 5 | mod websockets; 6 | #[cfg(feature = "ws")] 7 | pub use self::websockets::*; 8 | 9 | mod debug_response_body; 10 | pub use self::debug_response_body::*; 11 | 12 | mod expected_state; 13 | pub use self::expected_state::*; 14 | 15 | mod format_status_code_range; 16 | pub use self::format_status_code_range::*; 17 | 18 | mod status_code_formatter; 19 | pub use self::status_code_formatter::*; 20 | 21 | mod request_path_formatter; 22 | pub use self::request_path_formatter::*; 23 | 24 | mod query_params_store; 25 | pub use self::query_params_store::*; 26 | 27 | mod try_into_range_bounds; 28 | pub use self::try_into_range_bounds::*; 29 | 30 | mod starting_tcp_setup; 31 | pub use self::starting_tcp_setup::*; 32 | 33 | mod with_this_mut; 34 | pub use self::with_this_mut::*; 35 | -------------------------------------------------------------------------------- /src/util/new_random_tcp_listener.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use reserve_port::ReservedPort; 3 | use std::net::IpAddr; 4 | use std::net::Ipv4Addr; 5 | use std::net::SocketAddr; 6 | use std::net::TcpListener; 7 | 8 | pub(crate) const DEFAULT_IP_ADDRESS:
IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); 9 | 10 | /// Binds a [`std::net::TcpListener`] on the IP 127.0.0.1, using a random port. 11 | /// 12 | /// This is the best way to pick a local port. 13 | pub fn new_random_tcp_listener() -> Result<TcpListener> { 14 | new_random_tcp_listener_with_socket_addr().map(|(tcp_listener, _)| tcp_listener) 15 | } 16 | 17 | /// Binds a [`std::net::TcpListener`] on the IP 127.0.0.1, using a random port. 18 | /// 19 | /// It is returned with the [`std::net::SocketAddr`] available. 20 | pub fn new_random_tcp_listener_with_socket_addr() -> Result<(TcpListener, SocketAddr)> { 21 | let result = ReservedPort::random_permanently_reserved_tcp(DEFAULT_IP_ADDRESS)?; 22 | Ok(result) 23 | } 24 | -------------------------------------------------------------------------------- /src/testing/mod.rs: -------------------------------------------------------------------------------- 1 | //! This contains helpers used in our tests. 2 | 3 | #[cfg(not(feature = "old-json-diff"))] 4 | use crate::expect_json::expect_core::Context; 5 | #[cfg(not(feature = "old-json-diff"))] 6 | use crate::expect_json::expect_core::ExpectOp; 7 | #[cfg(not(feature = "old-json-diff"))] 8 | use crate::expect_json::expect_core::ExpectOpResult; 9 | 10 | // This needs to be the external crate, as the `::axum_test` path doesn't work within our tests.
11 | #[cfg(not(feature = "old-json-diff"))] 12 | use ::expect_json::expect_core::expect_op; 13 | 14 | #[cfg(not(feature = "old-json-diff"))] 15 | #[expect_op] 16 | #[derive(Clone, Debug)] 17 | pub struct ExpectStrMinLen { 18 | pub min: usize, 19 | } 20 | 21 | #[cfg(not(feature = "old-json-diff"))] 22 | impl ExpectOp for ExpectStrMinLen { 23 | fn on_string(&self, _context: &mut Context<'_>, received: &str) -> ExpectOpResult<()> { 24 | if received.len() < self.min { 25 | panic!("String is too short, received: {received}"); 26 | } 27 | 28 | Ok(()) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/internals/expected_state.rs: -------------------------------------------------------------------------------- 1 | /// The outcome a request is expected to have. 2 | #[derive(Debug, PartialEq, Clone, Copy, Eq, Hash)] 3 | pub enum ExpectedState { 4 | Success, 5 | Failure, 6 | None, 7 | } 8 | 9 | /// `Some(true)` maps to `Success`, `Some(false)` to `Failure`, and `None` to `None`. 10 | impl From<Option<bool>> for ExpectedState { 11 | fn from(maybe_success: Option<bool>) -> Self { 12 | match maybe_success { 13 | None => Self::None, 14 | Some(true) => Self::Success, 15 | Some(false) => Self::Failure, 16 | } 17 | } 18 | } 19 | 20 | #[cfg(test)] 21 | mod test_from { 22 | use super::*; 23 | 24 | #[test] 25 | fn it_should_turn_none_to_none() { 26 | let output = ExpectedState::from(None); 27 | assert_eq!(output, ExpectedState::None); 28 | } 29 | 30 | #[test] 31 | fn it_should_turn_true_to_success() { 32 | let output = ExpectedState::from(Some(true)); 33 | assert_eq!(output, ExpectedState::Success); 34 | } 35 | 36 | #[test] 37 | fn it_should_turn_false_to_failure() { 38 | let output = ExpectedState::from(Some(false)); 39 | assert_eq!(output, ExpectedState::Failure); 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Joseph Lenton 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a
copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/transport_layer/transport_layer.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use axum::body::Body; 3 | use http::Request; 4 | use http::Response; 5 | use std::fmt::Debug; 6 | use std::future::Future; 7 | use std::pin::Pin; 8 | use url::Url; 9 | 10 | use crate::transport_layer::TransportLayerType; 11 | 12 | pub trait TransportLayer: Debug + Send + Sync + 'static { 13 | fn send<'a>( 14 | &'a self, 15 | request: Request, 16 | ) -> Pin>> + Send>>; 17 | 18 | fn url(&self) -> Option<&Url> { 19 | None 20 | } 21 | 22 | fn transport_layer_type(&self) -> TransportLayerType; 23 | 24 | fn is_running(&self) -> bool; 25 | } 26 | 27 | #[cfg(test)] 28 | mod test_sync { 29 | use super::*; 30 | use tokio::sync::OnceCell; 31 | 32 | #[test] 33 | fn it_should_compile_with_tokyo_once_cell() { 34 | // if it compiles, it works! 
35 | fn _take_tokio_once_cell(layer: T) -> OnceCell> 36 | where 37 | T: TransportLayer, 38 | { 39 | OnceCell::new_with(Some(Box::new(layer))) 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/util/spawn_serve.rs: -------------------------------------------------------------------------------- 1 | use crate::util::ServeHandle; 2 | use axum::extract::Request; 3 | use axum::response::Response; 4 | use axum::serve; 5 | use axum::serve::IncomingStream; 6 | use axum::serve::Listener; 7 | use core::fmt::Debug; 8 | use std::convert::Infallible; 9 | use tokio::spawn; 10 | use tower::Service; 11 | 12 | /// A wrapper around [`axum::serve()`] for tests, 13 | /// which spawns the service in a new thread. 14 | /// 15 | /// The [`crate::util::ServeHandle`] returned will automatically attempt 16 | /// to terminate the service when dropped. 17 | pub fn spawn_serve(tcp_listener: L, make_service: M) -> ServeHandle 18 | where 19 | L: Listener, 20 | L::Addr: Debug, 21 | M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, 22 | for<'a> >>::Future: Send, 23 | S: Service + Clone + Send + 'static, 24 | S::Future: Send, 25 | { 26 | let server_handle = spawn(async move { 27 | serve(tcp_listener, make_service) 28 | .await 29 | .expect("Expect server to start serving"); 30 | }); 31 | 32 | ServeHandle::new(server_handle) 33 | } 34 | -------------------------------------------------------------------------------- /src/internals/status_code_formatter.rs: -------------------------------------------------------------------------------- 1 | use http::StatusCode; 2 | use std::fmt; 3 | 4 | #[derive(Debug, Copy, Clone, PartialEq)] 5 | pub struct StatusCodeFormatter(pub StatusCode); 6 | 7 | impl fmt::Display for StatusCodeFormatter { 8 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 9 | let code = self.0.as_u16(); 10 | let reason = self.0.canonical_reason().unwrap_or("unknown status code"); 11 | 12 | write!(f, 
"{code} ({reason})") 13 | } 14 | } 15 | 16 | #[cfg(test)] 17 | mod test_fmt { 18 | use super::*; 19 | 20 | #[test] 21 | fn it_should_format_with_reason_where_available() { 22 | let status_code = StatusCode::UNAUTHORIZED; 23 | let debug = StatusCodeFormatter(status_code); 24 | let output = debug.to_string(); 25 | 26 | assert_eq!(output, "401 (Unauthorized)"); 27 | } 28 | 29 | #[test] 30 | fn it_should_provide_only_number_where_reason_is_unavailable() { 31 | let status_code = StatusCode::from_u16(218).unwrap(); // Unofficial Apache status code. 32 | let debug = StatusCodeFormatter(status_code); 33 | let output = debug.to_string(); 34 | 35 | assert_eq!(output, "218 (unknown status code)"); 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/util/new_random_tokio_tcp_listener.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use reserve_port::ReservedPort; 3 | use std::net::IpAddr; 4 | use std::net::Ipv4Addr; 5 | use std::net::SocketAddr; 6 | use tokio::net::TcpListener as TokioTcpListener; 7 | 8 | pub(crate) const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); 9 | 10 | /// Binds a [`tokio::net::TcpListener`] on the IP 127.0.0.1, using a random port. 11 | /// 12 | /// This is the best way to pick a local port. 13 | pub fn new_random_tokio_tcp_listener() -> Result { 14 | new_random_tokio_tcp_listener_with_socket_addr() 15 | .map(|(tokio_tcp_listener, _)| tokio_tcp_listener) 16 | } 17 | 18 | /// Binds a [`tokio::net::TcpListener`] on the IP 127.0.0.1, using a random port. 19 | /// 20 | /// It is returned with the [`std::net::SocketAddr`] available. 
21 | pub fn new_random_tokio_tcp_listener_with_socket_addr() -> Result<(TokioTcpListener, SocketAddr)> { 22 | let (tcp_listener, random_socket) = 23 | ReservedPort::random_permanently_reserved_tcp(DEFAULT_IP_ADDRESS)?; 24 | 25 | tcp_listener.set_nonblocking(true)?; 26 | let tokio_tcp_listener = TokioTcpListener::from_std(tcp_listener)?; 27 | 28 | Ok((tokio_tcp_listener, random_socket)) 29 | } 30 | -------------------------------------------------------------------------------- /src/transport_layer/transport_layer_builder.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Context; 2 | use anyhow::Result; 3 | use reserve_port::ReservedPort; 4 | use std::net::IpAddr; 5 | use std::net::SocketAddr; 6 | use tokio::net::TcpListener; 7 | 8 | use crate::internals::StartingTcpSetup; 9 | 10 | #[derive(Debug, Clone)] 11 | pub struct TransportLayerBuilder { 12 | ip: Option, 13 | port: Option, 14 | } 15 | 16 | impl TransportLayerBuilder { 17 | pub(crate) fn new(ip: Option, port: Option) -> Self { 18 | Self { ip, port } 19 | } 20 | 21 | pub(crate) fn tcp_listener_with_reserved_port( 22 | self, 23 | ) -> Result<(SocketAddr, TcpListener, Option)> { 24 | let setup = StartingTcpSetup::new(self.ip, self.port) 25 | .context("Cannot create socket address for use")?; 26 | 27 | let socket_addr = setup.socket_addr; 28 | let tcp_listener = setup.tcp_listener; 29 | let maybe_reserved_port = setup.maybe_reserved_port; 30 | 31 | Ok((socket_addr, tcp_listener, maybe_reserved_port)) 32 | } 33 | 34 | pub fn tcp_listener(self) -> Result { 35 | let (_, tcp_listener, _) = self.tcp_listener_with_reserved_port()?; 36 | Ok(tcp_listener) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/transport.rs: -------------------------------------------------------------------------------- 1 | use std::net::IpAddr; 2 | 3 | /// Transport is for setting which transport mode for the `TestServer` 4 | /// to use 
when making requests. 5 | #[derive(Debug, Copy, Clone, Eq, PartialEq)] 6 | pub enum Transport { 7 | /// With this transport mode, `TestRequest` will use a mock HTTP 8 | /// transport. 9 | /// 10 | /// This is the Default Transport type. 11 | MockHttp, 12 | 13 | /// With this transport mode, a real web server will be spun up 14 | /// running on a random port. Requests made using the `TestRequest` 15 | /// will be made over the network stack. 16 | HttpRandomPort, 17 | 18 | /// With this transport mode, a real web server will be spun up. 19 | /// Where you can pick which IP and Port to use for this to bind to. 20 | /// 21 | /// Setting both `ip` and `port` to `None`, is the equivalent of 22 | /// using `Transport::HttpRandomPort`. 23 | HttpIpPort { 24 | /// Set the IP to use for the server. 25 | /// 26 | /// **Defaults** to `127.0.0.1`. 27 | ip: Option, 28 | 29 | /// Set the port number to use for the server. 30 | /// 31 | /// **Defaults** to a _random_ port. 32 | port: Option, 33 | }, 34 | } 35 | 36 | impl Default for Transport { 37 | fn default() -> Self { 38 | Self::MockHttp 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | 3 | use crate::transport_layer::TransportLayer; 4 | use crate::transport_layer::TransportLayerBuilder; 5 | 6 | // mod into_make_service_tower; 7 | 8 | mod into_make_service; 9 | mod into_make_service_with_connect_info; 10 | mod router; 11 | mod serve; 12 | mod with_graceful_shutdown; 13 | 14 | #[cfg(feature = "shuttle")] 15 | mod axum_service; 16 | #[cfg(feature = "shuttle")] 17 | mod shuttle_axum; 18 | 19 | /// 20 | /// This exists to unify how to send mock or real messages to different services. 
21 | /// This includes differences between [`Router`](::axum::Router), 22 | /// [`IntoMakeService`](::axum::routing::IntoMakeService), 23 | /// and [`IntoMakeServiceWithConnectInfo`](::axum::extract::connect_info::IntoMakeServiceWithConnectInfo). 24 | /// 25 | /// Implementing this will allow you to use the `TestServer` against other types. 26 | /// 27 | /// **Warning**, this trait may change in a future release. 28 | /// 29 | pub trait IntoTransportLayer: Sized { 30 | fn into_http_transport_layer( 31 | self, 32 | builder: TransportLayerBuilder, 33 | ) -> Result>; 34 | 35 | fn into_mock_transport_layer(self) -> Result>; 36 | 37 | fn into_default_transport( 38 | self, 39 | _builder: TransportLayerBuilder, 40 | ) -> Result> { 41 | self.into_mock_transport_layer() 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | .PHONY: fmt lint test build publish docs codecov 2 | 3 | fmt: 4 | cargo fmt 5 | 6 | lint: 7 | cargo +stable clippy 8 | 9 | test: 10 | cargo +stable check 11 | cargo +stable test --example=example-shuttle --features shuttle 12 | cargo +stable test --example=example-todo 13 | cargo +stable test --example=example-websocket-ping-pong --features ws 14 | cargo +stable test --example=example-websocket-chat --features ws 15 | cargo +stable test --features all 16 | cargo +stable test 17 | 18 | # Check deprecated old-json-diff still works 19 | cargo +stable test --features "old-json-diff" 20 | cargo +stable test --features "ws,old-json-diff" 21 | 22 | # Check minimum version works, excluding shuttle 23 | cargo +1.85 check --features "pretty-assertions,yaml,msgpack,reqwest,typed-routing,ws" 24 | 25 | # Check nightly also works, see https://github.com/JosephLenton/axum-test/issues/133 26 | cargo +nightly check --features all 27 | 28 | # Check the various build variations work 29 | cargo +stable check --no-default-features 30 | 
cargo +stable check --features all 31 | cargo +stable check --features pretty-assertions 32 | cargo +stable check --features yaml 33 | cargo +stable check --features msgpack 34 | cargo +stable check --features reqwest 35 | cargo +stable check --features shuttle 36 | cargo +stable check --features typed-routing 37 | cargo +stable check --features ws 38 | cargo +stable check --features "ws,old-json-diff" 39 | cargo +stable check --features reqwest 40 | cargo +stable check --features old-json-diff 41 | 42 | cargo +stable clippy --features all 43 | 44 | build: 45 | cargo +stable build 46 | 47 | publish: fmt lint test 48 | cargo publish 49 | 50 | docs: 51 | cargo doc --open --features all 52 | 53 | codecov: 54 | cargo llvm-cov --open 55 | -------------------------------------------------------------------------------- /src/multipart/mod.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! This supplies the building blocks for sending multipart forms using 3 | //! [`TestRequest::multipart()`](crate::TestRequest::multipart()). 4 | //! 5 | //! The request body can be built using [`MultipartForm`] and [`Part`]. 6 | //! 7 | //! # Simple example 8 | //! 9 | //! ```rust 10 | //! # async fn test() -> Result<(), Box> { 11 | //! # 12 | //! use axum::Router; 13 | //! use axum_test::TestServer; 14 | //! use axum_test::multipart::MultipartForm; 15 | //! 16 | //! let app = Router::new(); 17 | //! let server = TestServer::new(app)?; 18 | //! 19 | //! let multipart_form = MultipartForm::new() 20 | //! .add_text("name", "Joe") 21 | //! .add_text("animals", "foxes"); 22 | //! 23 | //! let response = server.post(&"/my-form") 24 | //! .multipart(multipart_form) 25 | //! .await; 26 | //! # 27 | //! # Ok(()) } 28 | //! ``` 29 | //! 30 | //! # Sending byte parts 31 | //! 32 | //! ```rust 33 | //! # async fn test() -> Result<(), Box> { 34 | //! # 35 | //! use axum::Router; 36 | //! use axum_test::TestServer; 37 | //! 
use axum_test::multipart::MultipartForm; 38 | //! use axum_test::multipart::Part; 39 | //! 40 | //! let app = Router::new(); 41 | //! let server = TestServer::new(app)?; 42 | //! 43 | //! let image_bytes = include_bytes!("../../README.md"); 44 | //! let image_part = Part::bytes(image_bytes.as_slice()) 45 | //! .file_name(&"README.md") 46 | //! .mime_type(&"text/markdown"); 47 | //! 48 | //! let multipart_form = MultipartForm::new() 49 | //! .add_part("file", image_part); 50 | //! 51 | //! let response = server.post(&"/my-form") 52 | //! .multipart(multipart_form) 53 | //! .await; 54 | //! # 55 | //! # Ok(()) } 56 | //! ``` 57 | //! 58 | 59 | mod multipart_form; 60 | pub use self::multipart_form::*; 61 | 62 | mod part; 63 | pub use self::part::*; 64 | -------------------------------------------------------------------------------- /src/multipart/multipart_form.rs: -------------------------------------------------------------------------------- 1 | use crate::multipart::Part; 2 | use axum::body::Body as AxumBody; 3 | use rust_multipart_rfc7578_2::client::multipart::Body as CommonMultipartBody; 4 | use rust_multipart_rfc7578_2::client::multipart::Form; 5 | use std::fmt::Display; 6 | use std::io::Cursor; 7 | 8 | #[derive(Debug)] 9 | pub struct MultipartForm { 10 | inner: Form<'static>, 11 | } 12 | 13 | impl MultipartForm { 14 | pub fn new() -> Self { 15 | Default::default() 16 | } 17 | 18 | /// Creates a text part, and adds it to be sent. 19 | pub fn add_text(mut self, name: N, text: T) -> Self 20 | where 21 | N: Display, 22 | T: ToString, 23 | { 24 | self.inner.add_text(name, text.to_string()); 25 | self 26 | } 27 | 28 | /// Adds a new section to this multipart form to be sent. 29 | /// 30 | /// See [`Part`](crate::multipart::Part). 
31 | pub fn add_part(mut self, name: N, part: Part) -> Self 32 | where 33 | N: Display, 34 | { 35 | let reader = Cursor::new(part.bytes); 36 | self.inner.add_reader_2( 37 | name, 38 | reader, 39 | part.file_name, 40 | Some(part.mime_type), 41 | part.headers, 42 | ); 43 | 44 | self 45 | } 46 | 47 | /// Returns the content type this form will use when it is sent. 48 | pub fn content_type(&self) -> String { 49 | self.inner.content_type() 50 | } 51 | } 52 | 53 | impl Default for MultipartForm { 54 | fn default() -> Self { 55 | Self { 56 | inner: Default::default(), 57 | } 58 | } 59 | } 60 | 61 | impl From for AxumBody { 62 | fn from(multipart: MultipartForm) -> Self { 63 | let inner_body: CommonMultipartBody = multipart.inner.into(); 64 | AxumBody::from_stream(inner_body) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/internals/transport_layer/http_transport_layer.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use axum::body::Body; 3 | use http::Request; 4 | use http::Response; 5 | use hyper_util::client::legacy::Client; 6 | use reserve_port::ReservedPort; 7 | use std::future::Future; 8 | use std::pin::Pin; 9 | use url::Url; 10 | 11 | use crate::transport_layer::TransportLayer; 12 | use crate::transport_layer::TransportLayerType; 13 | use crate::util::ServeHandle; 14 | 15 | #[derive(Debug)] 16 | pub struct HttpTransportLayer { 17 | #[allow(dead_code)] 18 | serve_handle: ServeHandle, 19 | 20 | #[allow(dead_code)] 21 | maybe_reserved_port: Option, 22 | 23 | url: Url, 24 | } 25 | 26 | impl HttpTransportLayer { 27 | pub(crate) fn new( 28 | serve_handle: ServeHandle, 29 | maybe_reserved_port: Option, 30 | url: Url, 31 | ) -> Self { 32 | Self { 33 | serve_handle, 34 | maybe_reserved_port, 35 | url, 36 | } 37 | } 38 | } 39 | 40 | impl TransportLayer for HttpTransportLayer { 41 | fn send<'a>( 42 | &'a self, 43 | request: Request, 44 | ) -> Pin>> + Send>> 
{ 45 | Box::pin(async { 46 | let client = Client::builder(hyper_util::rt::TokioExecutor::new()).build_http(); 47 | let hyper_response = client.request(request).await?; 48 | 49 | let (parts, response_body) = hyper_response.into_parts(); 50 | let returned_response: Response = 51 | Response::from_parts(parts, Body::new(response_body)); 52 | 53 | Ok(returned_response) 54 | }) 55 | } 56 | 57 | fn url(&self) -> Option<&Url> { 58 | Some(&self.url) 59 | } 60 | 61 | fn transport_layer_type(&self) -> TransportLayerType { 62 | TransportLayerType::Http 63 | } 64 | 65 | fn is_running(&self) -> bool { 66 | !self.serve_handle.is_finished() 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /src/internals/format_status_code_range.rs: -------------------------------------------------------------------------------- 1 | use http::StatusCode; 2 | use std::fmt::Write; 3 | use std::ops::Bound; 4 | use std::ops::RangeBounds; 5 | 6 | pub fn format_status_code_range(range: R) -> String 7 | where 8 | R: RangeBounds, 9 | { 10 | let mut output = String::new(); 11 | 12 | let start = range.start_bound(); 13 | let end = range.end_bound(); 14 | 15 | match start { 16 | Bound::Included(code) | Bound::Excluded(code) => { 17 | write!(output, "{}", code.as_u16()).expect("Failed to build debug string"); 18 | } 19 | Bound::Unbounded => {} 20 | }; 21 | 22 | write!(output, "..").expect("Failed to build debug string"); 23 | 24 | match end { 25 | Bound::Included(code) => { 26 | write!(output, "={}", code.as_u16()).expect("Failed to build debug string"); 27 | } 28 | Bound::Excluded(code) => { 29 | write!(output, "{}", code.as_u16()).expect("Failed to build debug string"); 30 | } 31 | Bound::Unbounded => {} 32 | }; 33 | 34 | output 35 | } 36 | 37 | #[cfg(test)] 38 | mod test_format_status_code_range { 39 | use super::*; 40 | 41 | #[test] 42 | fn it_should_format_range() { 43 | let output = format_status_code_range(StatusCode::OK..StatusCode::IM_USED); 44 | 
assert_eq!(output, "200..226"); 45 | } 46 | 47 | #[test] 48 | fn it_should_format_range_inclusive() { 49 | let output = format_status_code_range(StatusCode::OK..=StatusCode::IM_USED); 50 | assert_eq!(output, "200..=226"); 51 | } 52 | 53 | #[test] 54 | fn it_should_format_range_from() { 55 | let output = format_status_code_range(StatusCode::OK..); 56 | assert_eq!(output, "200.."); 57 | } 58 | 59 | #[test] 60 | fn it_should_format_range_to() { 61 | let output = format_status_code_range(..StatusCode::IM_USED); 62 | assert_eq!(output, "..226"); 63 | } 64 | 65 | #[test] 66 | fn it_should_format_range_to_inclusive() { 67 | let output = format_status_code_range(..=StatusCode::IM_USED); 68 | assert_eq!(output, "..=226"); 69 | } 70 | 71 | #[test] 72 | fn it_should_format_range_full() { 73 | let output = format_status_code_range(..); 74 | assert_eq!(output, ".."); 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/internals/transport_layer/mock_transport_layer.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Error as AnyhowError; 2 | use anyhow::Result; 3 | use axum::body::Body; 4 | use axum::response::Response as AxumResponse; 5 | use bytes::Bytes; 6 | use http::Request; 7 | use http::Response; 8 | use std::fmt::Debug; 9 | use std::future::Future; 10 | use std::pin::Pin; 11 | use tower::Service; 12 | use tower::util::ServiceExt; 13 | 14 | use crate::transport_layer::TransportLayer; 15 | use crate::transport_layer::TransportLayerType; 16 | 17 | pub struct MockTransportLayer { 18 | service: S, 19 | } 20 | 21 | impl MockTransportLayer 22 | where 23 | S: Service, Response = RouterService> + Clone + Send + Sync, 24 | AnyhowError: From, 25 | S::Future: Send, 26 | RouterService: Service, Response = AxumResponse>, 27 | { 28 | pub(crate) fn new(service: S) -> Self { 29 | Self { service } 30 | } 31 | } 32 | 33 | impl TransportLayer for MockTransportLayer 34 | where 35 | S: 
Service, Response = RouterService> + Clone + Send + Sync + 'static, 36 | AnyhowError: From, 37 | S::Future: Send + Sync, 38 | RouterService: Service, Response = AxumResponse> + Send, 39 | RouterService::Future: Send, 40 | AnyhowError: From, 41 | { 42 | fn send<'a>( 43 | &'a self, 44 | request: Request, 45 | ) -> Pin>> + Send>> { 46 | Box::pin(async { 47 | let body: Body = Bytes::new().into(); 48 | let empty_request = Request::builder() 49 | .body(body) 50 | .expect("should build empty request"); 51 | 52 | let service = self.service.clone(); 53 | let router = service.oneshot(empty_request).await?; 54 | 55 | let response = router.oneshot(request).await?; 56 | Ok(response) 57 | }) 58 | } 59 | 60 | fn transport_layer_type(&self) -> TransportLayerType { 61 | TransportLayerType::Mock 62 | } 63 | 64 | /// This will always return true. 65 | #[inline(always)] 66 | fn is_running(&self) -> bool { 67 | true 68 | } 69 | } 70 | 71 | impl Debug for MockTransportLayer { 72 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 73 | write!(f, "MockTransportLayer {{ service: {{unknown}} }}") 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/axum_service.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use axum::Router; 3 | use shuttle_axum::AxumService; 4 | 5 | use crate::transport_layer::IntoTransportLayer; 6 | use crate::transport_layer::TransportLayer; 7 | use crate::transport_layer::TransportLayerBuilder; 8 | 9 | impl IntoTransportLayer for AxumService { 10 | fn into_http_transport_layer( 11 | self, 12 | builder: TransportLayerBuilder, 13 | ) -> Result> { 14 | Router::into_http_transport_layer(self.0, builder) 15 | } 16 | 17 | fn into_mock_transport_layer(self) -> Result> { 18 | Router::into_mock_transport_layer(self.0) 19 | } 20 | } 21 | 22 | #[cfg(test)] 23 | mod 
test_into_http_transport_layer_for_axum_service { 24 | use super::*; 25 | 26 | use axum::Router; 27 | use axum::extract::State; 28 | use axum::routing::get; 29 | 30 | use crate::TestServer; 31 | 32 | async fn get_state(State(count): State) -> String { 33 | format!("count is {}", count) 34 | } 35 | 36 | #[tokio::test] 37 | async fn it_should_run() { 38 | // Build an application with a route. 39 | let app: AxumService = Router::new() 40 | .route("/count", get(get_state)) 41 | .with_state(123) 42 | .into(); 43 | 44 | // Run the server. 45 | let server = TestServer::builder() 46 | .http_transport() 47 | .build(app) 48 | .expect("Should create test server"); 49 | 50 | // Get the request. 51 | server.get(&"/count").await.assert_text(&"count is 123"); 52 | } 53 | } 54 | 55 | #[cfg(test)] 56 | mod test_into_mock_transport_layer_for_axum_service { 57 | use super::*; 58 | 59 | use axum::Router; 60 | use axum::extract::State; 61 | use axum::routing::get; 62 | 63 | use crate::TestServer; 64 | 65 | async fn get_state(State(count): State) -> String { 66 | format!("count is {}", count) 67 | } 68 | 69 | #[tokio::test] 70 | async fn it_should_run() { 71 | // Build an application with a route. 72 | let app: AxumService = Router::new() 73 | .route("/count", get(get_state)) 74 | .with_state(123) 75 | .into(); 76 | 77 | // Run the server. 78 | let server = TestServer::builder() 79 | .mock_transport() 80 | .build(app) 81 | .expect("Should create test server"); 82 | 83 | // Get the request. 
84 | server.get(&"/count").await.assert_text(&"count is 123"); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/internals/try_into_range_bounds.rs: -------------------------------------------------------------------------------- 1 | use std::convert::Infallible; 2 | use std::fmt::Debug; 3 | use std::ops::Range; 4 | use std::ops::RangeBounds; 5 | use std::ops::RangeFrom; 6 | use std::ops::RangeFull; 7 | use std::ops::RangeInclusive; 8 | use std::ops::RangeTo; 9 | use std::ops::RangeToInclusive; 10 | 11 | pub trait TryIntoRangeBounds { 12 | type TargetRange: RangeBounds; 13 | type Error: Debug; 14 | 15 | fn try_into_range_bounds(self) -> Result; 16 | } 17 | 18 | impl TryIntoRangeBounds for Range 19 | where 20 | A: TryInto, 21 | A::Error: Debug, 22 | { 23 | type TargetRange = Range; 24 | type Error = >::Error; 25 | 26 | fn try_into_range_bounds(self) -> Result { 27 | Ok(self.start.try_into()?..self.end.try_into()?) 28 | } 29 | } 30 | 31 | impl TryIntoRangeBounds for RangeFrom 32 | where 33 | A: TryInto, 34 | A::Error: Debug, 35 | { 36 | type TargetRange = RangeFrom; 37 | type Error = >::Error; 38 | 39 | fn try_into_range_bounds(self) -> Result { 40 | Ok(self.start.try_into()?..) 41 | } 42 | } 43 | 44 | impl TryIntoRangeBounds for RangeTo 45 | where 46 | A: TryInto, 47 | A::Error: Debug, 48 | { 49 | type TargetRange = RangeTo; 50 | type Error = >::Error; 51 | 52 | fn try_into_range_bounds(self) -> Result { 53 | Ok(..self.end.try_into()?) 54 | } 55 | } 56 | 57 | impl TryIntoRangeBounds for RangeInclusive 58 | where 59 | A: TryInto, 60 | A::Error: Debug, 61 | { 62 | type TargetRange = RangeInclusive; 63 | type Error = >::Error; 64 | 65 | fn try_into_range_bounds(self) -> Result { 66 | let (start, end) = self.into_inner(); 67 | Ok(start.try_into()?..=end.try_into()?) 
68 | } 69 | } 70 | 71 | impl TryIntoRangeBounds for RangeToInclusive 72 | where 73 | A: TryInto, 74 | A::Error: Debug, 75 | { 76 | type TargetRange = RangeToInclusive; 77 | type Error = >::Error; 78 | 79 | fn try_into_range_bounds(self) -> Result { 80 | Ok(..=self.end.try_into()?) 81 | } 82 | } 83 | 84 | impl TryIntoRangeBounds for RangeFull { 85 | type TargetRange = RangeFull; 86 | type Error = Infallible; 87 | 88 | fn try_into_range_bounds(self) -> Result { 89 | Ok(self) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /src/internals/request_path_formatter.rs: -------------------------------------------------------------------------------- 1 | use http::Method; 2 | use std::fmt; 3 | 4 | use crate::internals::QueryParamsStore; 5 | 6 | #[derive(Debug, Clone, PartialEq)] 7 | pub struct RequestPathFormatter<'a> { 8 | method: &'a Method, 9 | 10 | /// This is the path that the user requested. 11 | user_requested_path: &'a str, 12 | query_params: Option<&'a QueryParamsStore>, 13 | } 14 | 15 | impl<'a> RequestPathFormatter<'a> { 16 | pub fn new( 17 | method: &'a Method, 18 | user_requested_path: &'a str, 19 | query_params: Option<&'a QueryParamsStore>, 20 | ) -> Self { 21 | Self { 22 | method, 23 | user_requested_path, 24 | query_params, 25 | } 26 | } 27 | } 28 | 29 | impl fmt::Display for RequestPathFormatter<'_> { 30 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 31 | let method = &self.method; 32 | let user_requested_path = &self.user_requested_path; 33 | 34 | match self.query_params { 35 | None => { 36 | write!(f, "{method} {user_requested_path}") 37 | } 38 | Some(query_params) => { 39 | if query_params.is_empty() { 40 | write!(f, "{method} {user_requested_path}") 41 | } else { 42 | write!(f, "{method} {user_requested_path}?{query_params}") 43 | } 44 | } 45 | } 46 | } 47 | } 48 | 49 | #[cfg(test)] 50 | mod test_fmt { 51 | use super::*; 52 | 53 | #[test] 54 | fn it_should_format_with_path_given() { 55 | 
let query_params = QueryParamsStore::new(); 56 | let debug = RequestPathFormatter::new(&Method::GET, &"/donkeys", Some(&query_params)); 57 | let output = debug.to_string(); 58 | 59 | assert_eq!(output, "GET /donkeys"); 60 | } 61 | 62 | #[test] 63 | fn it_should_format_with_path_given_and_no_query_params() { 64 | let debug = RequestPathFormatter::new(&Method::GET, &"/donkeys", None); 65 | let output = debug.to_string(); 66 | 67 | assert_eq!(output, "GET /donkeys"); 68 | } 69 | 70 | #[test] 71 | fn it_should_format_with_path_given_and_query_params() { 72 | let mut query_params = QueryParamsStore::new(); 73 | query_params.add_raw("value=123".to_string()); 74 | query_params.add_raw("another-value".to_string()); 75 | 76 | let debug = RequestPathFormatter::new(&Method::GET, &"/donkeys", Some(&query_params)); 77 | let output = debug.to_string(); 78 | 79 | assert_eq!(output, "GET /donkeys?value=123&another-value"); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/shuttle_axum.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use shuttle_axum::ShuttleAxum; 3 | 4 | use crate::transport_layer::IntoTransportLayer; 5 | use crate::transport_layer::TransportLayer; 6 | use crate::transport_layer::TransportLayerBuilder; 7 | 8 | impl IntoTransportLayer for ShuttleAxum { 9 | fn into_http_transport_layer( 10 | self, 11 | builder: TransportLayerBuilder, 12 | ) -> Result> { 13 | self.map_err(Into::into) 14 | .and_then(|axum_service| axum_service.into_http_transport_layer(builder)) 15 | } 16 | 17 | fn into_mock_transport_layer(self) -> Result> { 18 | self.map_err(Into::into) 19 | .and_then(|axum_service| axum_service.into_mock_transport_layer()) 20 | } 21 | } 22 | 23 | #[cfg(test)] 24 | mod test_into_http_transport_layer_for_shuttle_axum { 25 | use super::*; 26 | 27 | use axum::Router; 28 | use axum::extract::State; 29 | use 
axum::routing::get; 30 | use shuttle_axum::AxumService; 31 | 32 | use crate::TestServer; 33 | 34 | async fn get_state(State(count): State) -> String { 35 | format!("count is {}", count) 36 | } 37 | 38 | #[tokio::test] 39 | async fn it_should_run() { 40 | // Build an application with a route. 41 | let router = Router::new() 42 | .route("/count", get(get_state)) 43 | .with_state(123); 44 | let app: ShuttleAxum = Ok(AxumService::from(router)); 45 | 46 | // Run the server. 47 | let server = TestServer::builder() 48 | .http_transport() 49 | .build(app) 50 | .expect("Should create test server"); 51 | 52 | // Get the request. 53 | server.get(&"/count").await.assert_text(&"count is 123"); 54 | } 55 | } 56 | 57 | #[cfg(test)] 58 | mod test_into_mock_transport_layer_for_shuttle_axum { 59 | use super::*; 60 | 61 | use axum::Router; 62 | use axum::extract::State; 63 | use axum::routing::get; 64 | use shuttle_axum::AxumService; 65 | 66 | use crate::TestServer; 67 | 68 | async fn get_state(State(count): State) -> String { 69 | format!("count is {}", count) 70 | } 71 | 72 | #[tokio::test] 73 | async fn it_should_run() { 74 | // Build an application with a route. 75 | let router = Router::new() 76 | .route("/count", get(get_state)) 77 | .with_state(123); 78 | let app: ShuttleAxum = Ok(AxumService::from(router)); 79 | 80 | // Run the server. 81 | let server = TestServer::builder() 82 | .mock_transport() 83 | .build(app) 84 | .expect("Should create test server"); 85 | 86 | // Get the request. 
87 | server.get(&"/count").await.assert_text(&"count is 123"); 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/router.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use axum::Router; 3 | 4 | use crate::transport_layer::IntoTransportLayer; 5 | use crate::transport_layer::TransportLayer; 6 | use crate::transport_layer::TransportLayerBuilder; 7 | 8 | impl IntoTransportLayer for Router<()> { 9 | fn into_http_transport_layer( 10 | self, 11 | builder: TransportLayerBuilder, 12 | ) -> Result> { 13 | self.into_make_service().into_http_transport_layer(builder) 14 | } 15 | 16 | fn into_mock_transport_layer(self) -> Result> { 17 | self.into_make_service().into_mock_transport_layer() 18 | } 19 | } 20 | 21 | #[cfg(test)] 22 | mod test_into_http_transport_layer { 23 | use axum::Router; 24 | use axum::extract::State; 25 | use axum::routing::get; 26 | 27 | use crate::TestServer; 28 | 29 | async fn get_ping() -> &'static str { 30 | "pong!" 31 | } 32 | 33 | async fn get_state(State(count): State) -> String { 34 | format!("count is {}", count) 35 | } 36 | 37 | #[tokio::test] 38 | async fn it_should_create_and_test_with_make_into_service() { 39 | // Build an application with a route. 40 | let app: Router = Router::new().route("/ping", get(get_ping)); 41 | 42 | // Run the server. 43 | let server = TestServer::builder() 44 | .http_transport() 45 | .build(app) 46 | .expect("Should create test server"); 47 | 48 | // Get the request. 49 | server.get(&"/ping").await.assert_text(&"pong!"); 50 | } 51 | 52 | #[tokio::test] 53 | async fn it_should_create_and_test_with_make_into_service_with_state() { 54 | // Build an application with a route. 55 | let app: Router = Router::new() 56 | .route("/count", get(get_state)) 57 | .with_state(123); 58 | 59 | // Run the server. 
60 | let server = TestServer::builder() 61 | .http_transport() 62 | .build(app) 63 | .expect("Should create test server"); 64 | 65 | // Get the request. 66 | server.get(&"/count").await.assert_text(&"count is 123"); 67 | } 68 | } 69 | 70 | #[cfg(test)] 71 | mod test_into_mock_transport_layer_for_router { 72 | use axum::Router; 73 | use axum::extract::State; 74 | use axum::routing::get; 75 | 76 | use crate::TestServer; 77 | 78 | async fn get_ping() -> &'static str { 79 | "pong!" 80 | } 81 | 82 | async fn get_state(State(count): State) -> String { 83 | format!("count is {}", count) 84 | } 85 | 86 | #[tokio::test] 87 | async fn it_should_create_and_test_with_make_into_service() { 88 | // Build an application with a route. 89 | let app: Router = Router::new().route("/ping", get(get_ping)); 90 | 91 | // Run the server. 92 | let server = TestServer::builder() 93 | .mock_transport() 94 | .build(app) 95 | .expect("Should create test server"); 96 | 97 | // Get the request. 98 | server.get(&"/ping").await.assert_text(&"pong!"); 99 | } 100 | 101 | #[tokio::test] 102 | async fn it_should_create_and_test_with_make_into_service_with_state() { 103 | // Build an application with a route. 104 | let app: Router = Router::new() 105 | .route("/count", get(get_state)) 106 | .with_state(123); 107 | 108 | // Run the server. 109 | let server = TestServer::builder() 110 | .mock_transport() 111 | .build(app) 112 | .expect("Should create test server"); 113 | 114 | // Get the request. 
115 | server.get(&"/count").await.assert_text(&"count is 123"); 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "axum-test" 3 | authors = ["Joseph Lenton "] 4 | version = "18.4.1" 5 | rust-version = "1.85" 6 | edition = "2024" 7 | license = "MIT" 8 | description = "For spinning up and testing Axum servers" 9 | keywords = ["testing", "test", "axum"] 10 | categories = ["web-programming::http-server", "development-tools::testing"] 11 | repository = "https://github.com/JosephLenton/axum-test" 12 | documentation = "https://docs.rs/axum-test" 13 | readme = "README.md" 14 | 15 | [package.metadata.docs.rs] 16 | all-features = true 17 | rustdoc-args = ["--cfg", "docsrs"] 18 | 19 | [[example]] 20 | name = "example-shuttle" 21 | path = "examples/example-shuttle/main.rs" 22 | required-features = ["shuttle"] 23 | 24 | [[example]] 25 | name = "example-websocket-chat" 26 | path = "examples/example-websocket-chat/main.rs" 27 | required-features = ["ws"] 28 | 29 | [[example]] 30 | name = "example-websocket-ping-pong" 31 | path = "examples/example-websocket-ping-pong/main.rs" 32 | required-features = ["ws"] 33 | 34 | [features] 35 | default = ["pretty-assertions"] 36 | 37 | all = ["pretty-assertions", "yaml", "msgpack", "reqwest", "shuttle", "typed-routing", "ws"] 38 | 39 | pretty-assertions = ["dep:pretty_assertions"] 40 | yaml = ["dep:serde_yaml"] 41 | msgpack = ["dep:rmp-serde"] 42 | reqwest = ["dep:reqwest"] 43 | shuttle = ["dep:shuttle-axum"] 44 | typed-routing = ["dep:axum-extra"] 45 | ws = ["axum/ws", "tokio/time", "dep:uuid", "dep:base64", "dep:tokio-tungstenite", "dep:futures-util"] 46 | 47 | # Deprecated, and will be removed in the future. 
48 | old-json-diff = ["dep:assert-json-diff"] 49 | 50 | [dependencies] 51 | axum = { version = "0.8.7", features = [] } 52 | anyhow = "1.0" 53 | bytes = "1.11" 54 | bytesize = "2.3" 55 | cookie = "0.18" 56 | expect-json = "1.7.1" 57 | http = "1.3" 58 | http-body-util = "0.1" 59 | hyper-util = { version = "0.1", features = ["client", "http1", "client-legacy"] } 60 | hyper = { version = "1.8", features = ["http1"] } 61 | mime = "0.3" 62 | rust-multipart-rfc7578_2 = "0.8" 63 | reserve-port = "2.3" 64 | serde = "1.0" 65 | serde_json = "1.0" 66 | serde_urlencoded = "0.7" 67 | smallvec = "1.15" 68 | tokio = { version = "1.48", features = ["rt"] } 69 | tower = { version = "0.5", features = ["util", "make"] } 70 | url = "2.5" 71 | 72 | # Pretty Assertions 73 | pretty_assertions = { version = "1.4", optional = true } 74 | 75 | # Yaml 76 | serde_yaml = { version = "0.9", optional = true } 77 | 78 | # Shuttle 79 | shuttle-axum = { version = "0.57", optional = true } 80 | 81 | # MsgPack 82 | rmp-serde = { version = "1.3", optional = true } 83 | 84 | # Typed Routing 85 | axum-extra = { version = "0.12", optional = true, features = ["routing", "typed-routing"] } 86 | 87 | # WebSockets 88 | uuid = { version = "1.18", optional = true, features = ["v4"]} 89 | base64 = { version = "0.22", optional = true } 90 | futures-util = { version = "0.3", optional = true } 91 | tokio-tungstenite = { version = "0.28", optional = true } 92 | 93 | # Reqwest 94 | reqwest = { version = "0.12", optional = true, features = ["cookies", "json", "stream", "multipart", "rustls-tls"] } 95 | 96 | # Old Json Diff 97 | assert-json-diff = { version = "2.0", optional = true } 98 | 99 | [dev-dependencies] 100 | axum = { version = "0.8", features = ["multipart", "tokio", "ws"] } 101 | axum-extra = { version = "0.12", features = ["cookie", "typed-routing", "query"] } 102 | axum-msgpack = "0.5" 103 | axum-yaml = "0.5" 104 | futures-util = "0.3" 105 | local-ip-address = "0.6" 106 | rand = { version = "0.9", 
features = ["small_rng"] } 107 | regex = "1.12" 108 | serde-email = { version = "3.1", features = ["serde"] } 109 | shuttle-axum = "0.57" 110 | shuttle-runtime = "0.57" 111 | tokio = { version = "1.48", features = ["rt", "rt-multi-thread", "sync", "time", "macros"] } 112 | tower-http = { version = "0.6", features = ["normalize-path"] } 113 | -------------------------------------------------------------------------------- /examples/example-websocket-ping-pong/main.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! This is a simple WebSocket example Application. 3 | //! You send it data, and it will send it back. 4 | //! 5 | //! At the bottom of this file are a series of tests for using websockets. 6 | //! 7 | //! ```bash 8 | //! # To run it's tests: 9 | //! cargo test --example=example-websocket-ping-pong --features ws 10 | //! ``` 11 | //! 12 | 13 | use anyhow::Result; 14 | use axum::Router; 15 | use axum::extract::WebSocketUpgrade; 16 | use axum::extract::ws::WebSocket; 17 | use axum::response::Response; 18 | use axum::routing::get; 19 | use axum::serve::serve; 20 | use std::net::IpAddr; 21 | use std::net::Ipv4Addr; 22 | use std::net::SocketAddr; 23 | use tokio::net::TcpListener; 24 | 25 | #[cfg(test)] 26 | use axum_test::TestServer; 27 | 28 | const PORT: u16 = 8080; 29 | 30 | #[tokio::main] 31 | async fn main() { 32 | let result: Result<()> = { 33 | let app = new_app(); 34 | 35 | // Start! 
36 | let ip_address = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); 37 | let address = SocketAddr::new(ip_address, PORT); 38 | let listener = TcpListener::bind(address).await.unwrap(); 39 | serve(listener, app.into_make_service()).await.unwrap(); 40 | 41 | Ok(()) 42 | }; 43 | 44 | match &result { 45 | Err(err) => eprintln!("{}", err), 46 | _ => {} 47 | }; 48 | } 49 | 50 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 51 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 52 | } 53 | 54 | async fn handle_ping_pong(mut socket: WebSocket) { 55 | while let Some(msg) = socket.recv().await { 56 | let msg = if let Ok(msg) = msg { 57 | msg 58 | } else { 59 | // client disconnected 60 | return; 61 | }; 62 | 63 | if socket.send(msg).await.is_err() { 64 | // client disconnected 65 | return; 66 | } 67 | } 68 | } 69 | 70 | pub(crate) fn new_app() -> Router { 71 | Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)) 72 | } 73 | 74 | #[cfg(test)] 75 | fn new_test_app() -> TestServer { 76 | let app = new_app(); 77 | TestServer::builder() 78 | .http_transport() // Important! It must be a HTTP Transport here. 
79 | .build(app) 80 | .unwrap() 81 | } 82 | 83 | #[cfg(test)] 84 | mod test_websockets_ping_pong { 85 | use super::*; 86 | 87 | use serde_json::json; 88 | 89 | #[tokio::test] 90 | async fn it_should_start_a_websocket_connection() { 91 | let server = new_test_app(); 92 | 93 | let response = server.get_websocket(&"/ws-ping-pong").await; 94 | 95 | response.assert_status_switching_protocols(); 96 | } 97 | 98 | #[tokio::test] 99 | async fn it_should_ping_pong_text() { 100 | let server = new_test_app(); 101 | 102 | let mut websocket = server 103 | .get_websocket(&"/ws-ping-pong") 104 | .await 105 | .into_websocket() 106 | .await; 107 | 108 | websocket.send_text("Hello!").await; 109 | websocket.assert_receive_text("Hello!").await; 110 | } 111 | 112 | #[tokio::test] 113 | async fn it_should_ping_pong_json() { 114 | let server = new_test_app(); 115 | 116 | let mut websocket = server 117 | .get_websocket(&"/ws-ping-pong") 118 | .await 119 | .into_websocket() 120 | .await; 121 | 122 | websocket 123 | .send_json(&json!({ 124 | "hello": "world", 125 | "numbers": [1, 2, 3], 126 | })) 127 | .await; 128 | websocket 129 | .assert_receive_json(&json!({ 130 | "hello": "world", 131 | "numbers": [1, 2, 3], 132 | })) 133 | .await; 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /src/internals/starting_tcp_setup.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Context; 2 | use anyhow::Result; 3 | use reserve_port::ReservedPort; 4 | use std::net::IpAddr; 5 | use std::net::Ipv4Addr; 6 | use std::net::SocketAddr; 7 | use std::net::TcpListener as StdTcpListener; 8 | use tokio::net::TcpListener as TokioTcpListener; 9 | 10 | pub const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); 11 | 12 | #[derive(Debug)] 13 | pub struct StartingTcpSetup { 14 | pub maybe_reserved_port: Option, 15 | pub socket_addr: SocketAddr, 16 | pub tcp_listener: TokioTcpListener, 17 | } 18 | 19 | impl 
StartingTcpSetup { 20 | pub fn new(maybe_ip: Option, maybe_port: Option) -> Result { 21 | let ip = maybe_ip.unwrap_or(DEFAULT_IP_ADDRESS); 22 | 23 | maybe_port 24 | .map(|port| Self::new_with_port(ip, port)) 25 | .unwrap_or_else(|| Self::new_without_port(ip)) 26 | } 27 | 28 | fn new_with_port(ip: IpAddr, port: u16) -> Result { 29 | ReservedPort::reserve_port(port)?; 30 | let socket_addr = SocketAddr::new(ip, port); 31 | let std_tcp_listener = StdTcpListener::bind(socket_addr) 32 | .context("Failed to create TCPListener for TestServer")?; 33 | std_tcp_listener.set_nonblocking(true)?; 34 | let tokio_tcp_listener = TokioTcpListener::from_std(std_tcp_listener)?; 35 | 36 | Ok(Self { 37 | maybe_reserved_port: None, 38 | socket_addr, 39 | tcp_listener: tokio_tcp_listener, 40 | }) 41 | } 42 | 43 | fn new_without_port(ip: IpAddr) -> Result { 44 | let (reserved_port, std_tcp_listener) = ReservedPort::random_with_tcp(ip)?; 45 | let socket_addr = SocketAddr::new(ip, reserved_port.port()); 46 | std_tcp_listener.set_nonblocking(true)?; 47 | let tokio_tcp_listener = TokioTcpListener::from_std(std_tcp_listener)?; 48 | 49 | Ok(Self { 50 | maybe_reserved_port: Some(reserved_port), 51 | socket_addr, 52 | tcp_listener: tokio_tcp_listener, 53 | }) 54 | } 55 | } 56 | 57 | #[cfg(test)] 58 | mod test_new { 59 | use super::*; 60 | use regex::Regex; 61 | use std::net::Ipv4Addr; 62 | 63 | #[tokio::test] 64 | async fn it_should_create_default_ip_with_random_port_when_none() { 65 | let ip = None; 66 | let port = None; 67 | 68 | let setup = StartingTcpSetup::new(ip, port).unwrap(); 69 | let addr = setup.socket_addr.to_string(); 70 | 71 | let regex = Regex::new("^127\\.0\\.0\\.1:[0-9]+$").unwrap(); 72 | let is_match = regex.is_match(&addr); 73 | assert!(is_match); 74 | } 75 | 76 | #[tokio::test] 77 | async fn it_should_create_ip_with_random_port_when_ip_given() { 78 | let ip = Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))); 79 | let port = None; 80 | 81 | let setup = StartingTcpSetup::new(ip, 
port).unwrap(); 82 | let addr = setup.socket_addr.to_string(); 83 | 84 | let regex = Regex::new("^127\\.0\\.0\\.1:[0-9]+$").unwrap(); 85 | let is_match = regex.is_match(&addr); 86 | assert!(is_match); 87 | } 88 | 89 | #[tokio::test] 90 | async fn it_should_create_default_ip_with_port_when_port_given() { 91 | let ip = None; 92 | let port = Some(8123); 93 | 94 | let setup = StartingTcpSetup::new(ip, port).unwrap(); 95 | let addr = setup.socket_addr.to_string(); 96 | 97 | assert_eq!(addr, "127.0.0.1:8123"); 98 | } 99 | 100 | #[tokio::test] 101 | async fn it_should_create_ip_port_given_when_both_given() { 102 | let ip = Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))); 103 | let port = Some(8124); 104 | 105 | let setup = StartingTcpSetup::new(ip, port).unwrap(); 106 | let addr = setup.socket_addr.to_string(); 107 | 108 | assert_eq!(addr, "127.0.0.1:8124"); 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /src/internals/query_params_store.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use serde::Serialize; 3 | use smallvec::SmallVec; 4 | use std::fmt::Display; 5 | use std::fmt::Formatter; 6 | use std::fmt::Result as FmtResult; 7 | 8 | #[derive(Debug, Clone, PartialEq)] 9 | pub struct QueryParamsStore { 10 | query_params: SmallVec<[String; 0]>, 11 | } 12 | 13 | impl QueryParamsStore { 14 | pub fn new() -> Self { 15 | Self { 16 | query_params: SmallVec::new(), 17 | } 18 | } 19 | 20 | pub fn add(&mut self, query_params: V) -> Result<()> 21 | where 22 | V: Serialize, 23 | { 24 | let value_raw = ::serde_urlencoded::to_string(query_params)?; 25 | self.add_raw(value_raw); 26 | 27 | Ok(()) 28 | } 29 | 30 | pub fn add_raw(&mut self, value_raw: String) { 31 | self.query_params.push(value_raw); 32 | } 33 | 34 | pub fn clear(&mut self) { 35 | self.query_params.clear(); 36 | } 37 | 38 | pub fn is_empty(&self) -> bool { 39 | self.query_params.is_empty() 40 | } 41 | 42 
| pub fn has_content(&self) -> bool { 43 | !self.is_empty() 44 | } 45 | } 46 | 47 | impl Display for QueryParamsStore { 48 | fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { 49 | let mut is_joining = false; 50 | for query in &self.query_params { 51 | if is_joining { 52 | write!(f, "&")?; 53 | } 54 | 55 | write!(f, "{query}")?; 56 | is_joining = true; 57 | } 58 | 59 | Ok(()) 60 | } 61 | } 62 | 63 | #[cfg(test)] 64 | mod test_add { 65 | use super::*; 66 | 67 | #[test] 68 | fn it_should_add_multiple_key_values() { 69 | let mut params = QueryParamsStore::new(); 70 | 71 | params 72 | .add(&[("key", "value"), ("another", "value")]) 73 | .unwrap(); 74 | 75 | assert_eq!("key=value&another=value", params.to_string()); 76 | } 77 | 78 | #[test] 79 | fn it_should_add_multiple_calls() { 80 | let mut params = QueryParamsStore::new(); 81 | 82 | params.add(&[("key", "value")]).unwrap(); 83 | params.add(&[("another", "value")]).unwrap(); 84 | 85 | assert_eq!("key=value&another=value", params.to_string()); 86 | } 87 | 88 | #[test] 89 | fn it_should_reject_raw_string() { 90 | let mut params = QueryParamsStore::new(); 91 | 92 | let result = params.add("key=value"); 93 | 94 | assert!(result.is_err()); 95 | } 96 | 97 | #[test] 98 | fn it_should_add_query_param_strings_deserialized() { 99 | let mut params = QueryParamsStore::new(); 100 | 101 | params.add(&[("key", "value&another=value")]).unwrap(); 102 | 103 | assert_eq!("key=value%26another%3Dvalue", params.to_string()); 104 | } 105 | } 106 | 107 | #[cfg(test)] 108 | mod test_add_raw { 109 | use crate::internals::QueryParamsStore; 110 | 111 | #[test] 112 | fn it_should_add_key_value_pairs_correctly() { 113 | let mut params = QueryParamsStore::new(); 114 | 115 | params.add_raw("key=value".to_string()); 116 | params.add_raw("another=value".to_string()); 117 | 118 | assert_eq!("key=value&another=value", params.to_string()); 119 | } 120 | 121 | #[test] 122 | fn it_should_add_single_keys_correctly() { 123 | let mut params = 
QueryParamsStore::new(); 124 | 125 | params.add_raw("key".to_string()); 126 | params.add_raw("another".to_string()); 127 | 128 | assert_eq!("key&another", params.to_string()); 129 | } 130 | 131 | #[test] 132 | fn it_should_add_query_param_strings_correctly() { 133 | let mut params = QueryParamsStore::new(); 134 | 135 | params.add_raw("key=value&another=value".to_string()); 136 | params.add_raw("more=value".to_string()); 137 | 138 | assert_eq!("key=value&another=value&more=value", params.to_string()); 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /src/test_server/server_shared_state.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Context; 2 | use anyhow::Result; 3 | use cookie::Cookie; 4 | use cookie::CookieJar; 5 | use http::HeaderName; 6 | use http::HeaderValue; 7 | use serde::Serialize; 8 | use std::sync::Arc; 9 | use std::sync::Mutex; 10 | 11 | use crate::internals::QueryParamsStore; 12 | use crate::internals::with_this_mut; 13 | 14 | #[derive(Debug)] 15 | pub(crate) struct ServerSharedState { 16 | scheme: Option, 17 | cookies: CookieJar, 18 | query_params: QueryParamsStore, 19 | headers: Vec<(HeaderName, HeaderValue)>, 20 | } 21 | 22 | impl ServerSharedState { 23 | pub(crate) fn new() -> Self { 24 | Self { 25 | scheme: None, 26 | cookies: CookieJar::new(), 27 | query_params: QueryParamsStore::new(), 28 | headers: Vec::new(), 29 | } 30 | } 31 | 32 | pub(crate) fn scheme(&self) -> Option<&str> { 33 | self.scheme.as_deref() 34 | } 35 | 36 | pub(crate) fn cookies(&self) -> &CookieJar { 37 | &self.cookies 38 | } 39 | 40 | pub(crate) fn query_params(&self) -> &QueryParamsStore { 41 | &self.query_params 42 | } 43 | 44 | pub(crate) fn headers(&self) -> &Vec<(HeaderName, HeaderValue)> { 45 | &self.headers 46 | } 47 | 48 | /// Adds the given cookies. 49 | /// 50 | /// They will be stored over the top of the existing cookies. 
51 | pub(crate) fn add_cookies_by_header<'a, I>( 52 | this: &Arc>, 53 | cookie_headers: I, 54 | ) -> Result<()> 55 | where 56 | I: Iterator, 57 | { 58 | with_this_mut(this, "add_cookies_by_header", |this| { 59 | for cookie_header in cookie_headers { 60 | let cookie_header_str = cookie_header 61 | .to_str() 62 | .context("Reading cookie header for storing in the `TestServer`") 63 | .unwrap(); 64 | 65 | let cookie: Cookie<'static> = Cookie::parse(cookie_header_str)?.into_owned(); 66 | this.cookies.add(cookie); 67 | } 68 | 69 | Ok(()) as Result<()> 70 | })? 71 | } 72 | 73 | /// Adds the given cookies. 74 | /// 75 | /// They will be stored over the top of the existing cookies. 76 | pub(crate) fn clear_cookies(this: &Arc>) -> Result<()> { 77 | with_this_mut(this, "clear_cookies", |this| { 78 | this.cookies = CookieJar::new(); 79 | }) 80 | } 81 | 82 | /// Adds the given cookies. 83 | /// 84 | /// They will be stored over the top of the existing cookies. 85 | pub(crate) fn add_cookies(this: &Arc>, cookies: CookieJar) -> Result<()> { 86 | with_this_mut(this, "add_cookies", |this| { 87 | for cookie in cookies.iter() { 88 | this.cookies.add(cookie.to_owned()); 89 | } 90 | }) 91 | } 92 | 93 | pub(crate) fn add_cookie(this: &Arc>, cookie: Cookie) -> Result<()> { 94 | with_this_mut(this, "add_cookie", |this| { 95 | this.cookies.add(cookie.into_owned()); 96 | }) 97 | } 98 | 99 | pub(crate) fn add_query_params(this: &Arc>, query_params: V) -> Result<()> 100 | where 101 | V: Serialize, 102 | { 103 | with_this_mut(this, "add_query_params", |this| { 104 | this.query_params.add(query_params) 105 | })? 106 | } 107 | 108 | pub(crate) fn add_query_param(this: &Arc>, key: &str, value: V) -> Result<()> 109 | where 110 | V: Serialize, 111 | { 112 | with_this_mut(this, "add_query_param", |this| { 113 | this.query_params.add(&[(key, value)]) 114 | })? 
115 | } 116 | 117 | pub(crate) fn add_raw_query_param(this: &Arc>, raw_value: &str) -> Result<()> { 118 | with_this_mut(this, "add_raw_query_param", |this| { 119 | this.query_params.add_raw(raw_value.to_string()) 120 | }) 121 | } 122 | 123 | pub(crate) fn clear_query_params(this: &Arc>) -> Result<()> { 124 | with_this_mut(this, "clear_query_params", |this| this.query_params.clear()) 125 | } 126 | 127 | pub(crate) fn clear_headers(this: &Arc>) -> Result<()> { 128 | with_this_mut(this, "clear_headers", |this| this.headers.clear()) 129 | } 130 | 131 | pub(crate) fn add_header( 132 | this: &Arc>, 133 | name: HeaderName, 134 | value: HeaderValue, 135 | ) -> Result<()> { 136 | with_this_mut(this, "add_header", |this| this.headers.push((name, value))) 137 | } 138 | 139 | pub(crate) fn set_scheme(this: &Arc>, scheme: String) -> Result<()> { 140 | with_this_mut(this, "set_scheme", |this| this.scheme = Some(scheme)) 141 | } 142 | 143 | pub(crate) fn set_scheme_unlocked(&mut self, scheme: String) { 144 | self.scheme = Some(scheme); 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/into_make_service_with_connect_info.rs: -------------------------------------------------------------------------------- 1 | use crate::internals::HttpTransportLayer; 2 | use crate::transport_layer::IntoTransportLayer; 3 | use crate::transport_layer::TransportLayer; 4 | use crate::transport_layer::TransportLayerBuilder; 5 | use crate::util::spawn_serve; 6 | use anyhow::Result; 7 | use anyhow::anyhow; 8 | use axum::extract::Request as AxumRequest; 9 | use axum::extract::connect_info::IntoMakeServiceWithConnectInfo; 10 | use axum::response::Response as AxumResponse; 11 | use axum::serve::IncomingStream; 12 | use std::convert::Infallible; 13 | use tokio::net::TcpListener; 14 | use tower::Service; 15 | use url::Url; 16 | 17 | impl IntoTransportLayer for IntoMakeServiceWithConnectInfo 18 | where 19 | for<'a> C: 
axum::extract::connect_info::Connected>, 20 | S: Service + Clone + Send + 'static, 21 | S::Future: Send, 22 | { 23 | fn into_http_transport_layer( 24 | self, 25 | builder: TransportLayerBuilder, 26 | ) -> Result> { 27 | let (socket_addr, tcp_listener, maybe_reserved_port) = 28 | builder.tcp_listener_with_reserved_port()?; 29 | 30 | let serve_handle = spawn_serve(tcp_listener, self); 31 | let server_address = format!("http://{socket_addr}"); 32 | let server_url: Url = server_address.parse()?; 33 | 34 | Ok(Box::new(HttpTransportLayer::new( 35 | serve_handle, 36 | maybe_reserved_port, 37 | server_url, 38 | ))) 39 | } 40 | 41 | fn into_mock_transport_layer(self) -> Result> { 42 | Err(anyhow!( 43 | "`IntoMakeServiceWithConnectInfo` cannot be mocked, as it's underlying implementation requires a real connection. Set the `TestServerConfig` to run with a transport of `HttpRandomPort`, or a `HttpIpPort`." 44 | )) 45 | } 46 | 47 | fn into_default_transport( 48 | self, 49 | builder: TransportLayerBuilder, 50 | ) -> Result> { 51 | self.into_http_transport_layer(builder) 52 | } 53 | } 54 | 55 | #[cfg(test)] 56 | mod test_into_http_transport_layer_for_into_make_service_with_connect_info { 57 | use crate::TestServer; 58 | use axum::Router; 59 | use axum::ServiceExt; 60 | use axum::extract::Request; 61 | use axum::routing::get; 62 | use std::net::SocketAddr; 63 | use tower::Layer; 64 | use tower_http::normalize_path::NormalizePathLayer; 65 | 66 | async fn get_ping() -> &'static str { 67 | "pong!" 68 | } 69 | 70 | #[tokio::test] 71 | async fn it_should_create_and_test_with_make_into_service_with_connect_info() { 72 | // Build an application with a route. 73 | let app = Router::new() 74 | .route("/ping", get(get_ping)) 75 | .into_make_service_with_connect_info::(); 76 | 77 | // Run the server. 78 | let server = TestServer::builder() 79 | .http_transport() 80 | .build(app) 81 | .expect("Should create test server"); 82 | 83 | // Get the request. 
84 | server.get(&"/ping").await.assert_text(&"pong!"); 85 | } 86 | 87 | #[tokio::test] 88 | async fn it_should_create_and_run_with_router_wrapped_service() { 89 | // Build an application with a route. 90 | let router = Router::new().route("/ping", get(get_ping)); 91 | let normalized_router = NormalizePathLayer::trim_trailing_slash().layer(router); 92 | let app = ServiceExt::::into_make_service_with_connect_info::( 93 | normalized_router, 94 | ); 95 | 96 | // Run the server. 97 | let server = TestServer::builder() 98 | .http_transport() 99 | .build(app) 100 | .expect("Should create test server"); 101 | 102 | // Get the request. 103 | server.get(&"/ping").await.assert_text(&"pong!"); 104 | } 105 | } 106 | 107 | #[cfg(test)] 108 | mod test_into_mock_transport_layer_for_into_make_service_with_connect_info { 109 | use crate::TestServer; 110 | use axum::Router; 111 | use axum::routing::get; 112 | use std::net::SocketAddr; 113 | 114 | async fn get_ping() -> &'static str { 115 | "pong!" 116 | } 117 | 118 | #[tokio::test] 119 | async fn it_should_panic_when_creating_test_using_mock() { 120 | // Build an application with a route. 121 | let app = Router::new() 122 | .route("/ping", get(get_ping)) 123 | .into_make_service_with_connect_info::(); 124 | 125 | // Build the server. 126 | let result = TestServer::builder().mock_transport().build(app); 127 | let err = result.unwrap_err(); 128 | let err_msg = err.to_string(); 129 | 130 | assert_eq!( 131 | err_msg, 132 | "`IntoMakeServiceWithConnectInfo` cannot be mocked, as it's underlying implementation requires a real connection. Set the `TestServerConfig` to run with a transport of `HttpRandomPort`, or a `HttpIpPort`." 
133 | ); 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/serve.rs: -------------------------------------------------------------------------------- 1 | use crate::internals::HttpTransportLayer; 2 | use crate::transport_layer::IntoTransportLayer; 3 | use crate::transport_layer::TransportLayer; 4 | use crate::transport_layer::TransportLayerBuilder; 5 | use crate::util::ServeHandle; 6 | use anyhow::Context; 7 | use anyhow::Result; 8 | use anyhow::anyhow; 9 | use axum::extract::Request; 10 | use axum::response::Response; 11 | use axum::serve::IncomingStream; 12 | use axum::serve::Serve; 13 | use std::convert::Infallible; 14 | use tokio::net::TcpListener; 15 | use tokio::spawn; 16 | use tower::Service; 17 | use url::Url; 18 | 19 | impl IntoTransportLayer for Serve 20 | where 21 | M: for<'a> Service, Error = Infallible, Response = S> 22 | + Send 23 | + 'static, 24 | for<'a> >>::Future: Send, 25 | S: Service + Clone + Send + 'static, 26 | S::Future: Send, 27 | { 28 | fn into_http_transport_layer( 29 | self, 30 | _builder: TransportLayerBuilder, 31 | ) -> Result> { 32 | Err(anyhow!( 33 | "`Serve` must be started with http or mock transport. Do not set any transport on `TestServerConfig`." 34 | )) 35 | } 36 | 37 | fn into_mock_transport_layer(self) -> Result> { 38 | Err(anyhow!( 39 | "`Serve` cannot be mocked, as it's underlying implementation requires a real connection. Do not set any transport on `TestServerConfig`." 
40 | )) 41 | } 42 | 43 | fn into_default_transport( 44 | self, 45 | _builder: TransportLayerBuilder, 46 | ) -> Result> { 47 | let socket_addr = self.local_addr()?; 48 | 49 | let join_handle = spawn(async move { 50 | self.await 51 | .context("Failed to create ::axum::Server for TestServer") 52 | .expect("Expect server to start serving"); 53 | }); 54 | 55 | let server_address = format!("http://{socket_addr}"); 56 | let server_url: Url = server_address.parse()?; 57 | 58 | Ok(Box::new(HttpTransportLayer::new( 59 | ServeHandle::new(join_handle), 60 | None, 61 | server_url, 62 | ))) 63 | } 64 | } 65 | 66 | #[cfg(test)] 67 | mod test_into_http_transport_layer { 68 | use crate::TestServer; 69 | use crate::util::new_random_tokio_tcp_listener; 70 | use axum::Router; 71 | use axum::routing::IntoMakeService; 72 | use axum::routing::get; 73 | use axum::serve; 74 | 75 | async fn get_ping() -> &'static str { 76 | "pong!" 77 | } 78 | 79 | #[tokio::test] 80 | #[should_panic] 81 | async fn it_should_panic_when_run_with_http() { 82 | // Build an application with a route. 83 | let app: IntoMakeService = Router::new() 84 | .route("/ping", get(get_ping)) 85 | .into_make_service(); 86 | let port = new_random_tokio_tcp_listener().unwrap(); 87 | let application = serve(port, app); 88 | 89 | // Run the server. 90 | TestServer::builder() 91 | .http_transport() 92 | .build(application) 93 | .expect("Should create test server"); 94 | } 95 | } 96 | 97 | #[cfg(test)] 98 | mod test_into_mock_transport_layer { 99 | use crate::TestServer; 100 | use crate::util::new_random_tokio_tcp_listener; 101 | use axum::Router; 102 | use axum::routing::IntoMakeService; 103 | use axum::routing::get; 104 | use axum::serve; 105 | 106 | async fn get_ping() -> &'static str { 107 | "pong!" 108 | } 109 | 110 | #[tokio::test] 111 | #[should_panic] 112 | async fn it_should_panic_when_run_with_mock_http() { 113 | // Build an application with a route. 
114 | let app: IntoMakeService = Router::new() 115 | .route("/ping", get(get_ping)) 116 | .into_make_service(); 117 | let port = new_random_tokio_tcp_listener().unwrap(); 118 | let application = serve(port, app); 119 | 120 | // Run the server. 121 | TestServer::builder() 122 | .mock_transport() 123 | .build(application) 124 | .expect("Should create test server"); 125 | } 126 | } 127 | 128 | #[cfg(test)] 129 | mod test_into_default_transport { 130 | use crate::TestServer; 131 | use crate::util::new_random_tokio_tcp_listener; 132 | use axum::Router; 133 | use axum::routing::IntoMakeService; 134 | use axum::routing::get; 135 | use axum::serve; 136 | 137 | async fn get_ping() -> &'static str { 138 | "pong!" 139 | } 140 | 141 | #[tokio::test] 142 | async fn it_should_run_service() { 143 | // Build an application with a route. 144 | let app: IntoMakeService = Router::new() 145 | .route("/ping", get(get_ping)) 146 | .into_make_service(); 147 | let port = new_random_tokio_tcp_listener().unwrap(); 148 | let application = serve(port, app); 149 | 150 | // Run the server. 151 | let server = TestServer::builder() 152 | .build(application) 153 | .expect("Should create test server"); 154 | 155 | // Get the request. 
156 | server.get(&"/ping").await.assert_text(&"pong!"); 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/with_graceful_shutdown.rs: -------------------------------------------------------------------------------- 1 | use crate::internals::HttpTransportLayer; 2 | use crate::transport_layer::IntoTransportLayer; 3 | use crate::transport_layer::TransportLayer; 4 | use crate::transport_layer::TransportLayerBuilder; 5 | use crate::util::ServeHandle; 6 | use anyhow::Context; 7 | use anyhow::Result; 8 | use anyhow::anyhow; 9 | use axum::extract::Request; 10 | use axum::response::Response; 11 | use axum::serve::IncomingStream; 12 | use axum::serve::WithGracefulShutdown; 13 | use std::convert::Infallible; 14 | use std::future::Future; 15 | use tokio::net::TcpListener; 16 | use tokio::spawn; 17 | use tower::Service; 18 | use url::Url; 19 | 20 | impl IntoTransportLayer for WithGracefulShutdown 21 | where 22 | M: for<'a> Service, Error = Infallible, Response = S> 23 | + Send 24 | + 'static, 25 | for<'a> >>::Future: Send, 26 | S: Service + Clone + Send + 'static, 27 | S::Future: Send, 28 | F: Future + Send + 'static, 29 | { 30 | fn into_http_transport_layer( 31 | self, 32 | _builder: TransportLayerBuilder, 33 | ) -> Result> { 34 | Err(anyhow!( 35 | "`WithGracefulShutdown` must be started with http or mock transport. Do not set any transport on `TestServerConfig`." 36 | )) 37 | } 38 | 39 | fn into_mock_transport_layer(self) -> Result> { 40 | Err(anyhow!( 41 | "`WithGracefulShutdown` cannot be mocked, as it's underlying implementation requires a real connection. Do not set any transport on `TestServerConfig`." 
42 | )) 43 | } 44 | 45 | fn into_default_transport( 46 | self, 47 | _builder: TransportLayerBuilder, 48 | ) -> Result> { 49 | let socket_addr = self.local_addr()?; 50 | 51 | let join_handle = spawn(async move { 52 | self.await 53 | .context("Failed to create ::axum::Server for TestServer") 54 | .expect("Expect server to start serving"); 55 | }); 56 | 57 | let server_address = format!("http://{socket_addr}"); 58 | let server_url: Url = server_address.parse()?; 59 | 60 | Ok(Box::new(HttpTransportLayer::new( 61 | ServeHandle::new(join_handle), 62 | None, 63 | server_url, 64 | ))) 65 | } 66 | } 67 | 68 | #[cfg(test)] 69 | mod test_into_http_transport_layer { 70 | use crate::TestServer; 71 | use crate::util::new_random_tokio_tcp_listener; 72 | use axum::Router; 73 | use axum::routing::IntoMakeService; 74 | use axum::routing::get; 75 | use axum::serve; 76 | use std::future::pending; 77 | 78 | async fn get_ping() -> &'static str { 79 | "pong!" 80 | } 81 | 82 | #[tokio::test] 83 | #[should_panic] 84 | async fn it_should_panic_when_run_with_http() { 85 | // Build an application with a route. 86 | let app: IntoMakeService = Router::new() 87 | .route("/ping", get(get_ping)) 88 | .into_make_service(); 89 | let port = new_random_tokio_tcp_listener().unwrap(); 90 | let application = serve(port, app).with_graceful_shutdown(pending()); 91 | 92 | // Run the server. 93 | TestServer::builder() 94 | .http_transport() 95 | .build(application) 96 | .expect("Should create test server"); 97 | } 98 | } 99 | 100 | #[cfg(test)] 101 | mod test_into_mock_transport_layer { 102 | use crate::TestServer; 103 | use crate::util::new_random_tokio_tcp_listener; 104 | use axum::Router; 105 | use axum::routing::IntoMakeService; 106 | use axum::routing::get; 107 | use axum::serve; 108 | use std::future::pending; 109 | 110 | async fn get_ping() -> &'static str { 111 | "pong!" 
112 | } 113 | 114 | #[tokio::test] 115 | #[should_panic] 116 | async fn it_should_panic_when_run_with_mock_http() { 117 | // Build an application with a route. 118 | let app: IntoMakeService = Router::new() 119 | .route("/ping", get(get_ping)) 120 | .into_make_service(); 121 | let port = new_random_tokio_tcp_listener().unwrap(); 122 | let application = serve(port, app).with_graceful_shutdown(pending()); 123 | 124 | // Run the server. 125 | TestServer::builder() 126 | .mock_transport() 127 | .build(application) 128 | .expect("Should create test server"); 129 | } 130 | } 131 | 132 | #[cfg(test)] 133 | mod test_into_default_transport { 134 | use crate::TestServer; 135 | use crate::util::new_random_tokio_tcp_listener; 136 | use axum::Router; 137 | use axum::routing::IntoMakeService; 138 | use axum::routing::get; 139 | use axum::serve; 140 | use std::future::pending; 141 | 142 | async fn get_ping() -> &'static str { 143 | "pong!" 144 | } 145 | 146 | #[tokio::test] 147 | async fn it_should_run_service() { 148 | // Build an application with a route. 149 | let app: IntoMakeService = Router::new() 150 | .route("/ping", get(get_ping)) 151 | .into_make_service(); 152 | let port = new_random_tokio_tcp_listener().unwrap(); 153 | let application = serve(port, app).with_graceful_shutdown(pending()); 154 | 155 | // Run the server. 156 | let server = TestServer::builder() 157 | .build(application) 158 | .expect("Should create test server"); 159 | 160 | // Get the request. 
161 | server.get(&"/ping").await.assert_text(&"pong!"); 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /src/test_server_config.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | 3 | use crate::TestServer; 4 | use crate::TestServerBuilder; 5 | use crate::Transport; 6 | use crate::transport_layer::IntoTransportLayer; 7 | 8 | /// This is for customising the [`TestServer`](crate::TestServer) on construction. 9 | /// It implements [`Default`] to ease building. 10 | /// 11 | /// ```rust 12 | /// use axum_test::TestServerConfig; 13 | /// 14 | /// let config = TestServerConfig { 15 | /// save_cookies: true, 16 | /// ..TestServerConfig::default() 17 | /// }; 18 | /// ``` 19 | /// 20 | /// These can be passed to `TestServer::new_with_config`: 21 | /// 22 | /// ```rust 23 | /// # async fn test() -> Result<(), Box> { 24 | /// # 25 | /// use axum::Router; 26 | /// use axum_test::TestServer; 27 | /// use axum_test::TestServerConfig; 28 | /// 29 | /// let my_app = Router::new(); 30 | /// 31 | /// let config = TestServerConfig { 32 | /// save_cookies: true, 33 | /// ..TestServerConfig::default() 34 | /// }; 35 | /// 36 | /// // Build the Test Server 37 | /// let server = TestServer::new_with_config(my_app, config)?; 38 | /// # 39 | /// # Ok(()) 40 | /// # } 41 | /// ``` 42 | /// 43 | #[derive(Debug, Clone, Eq, PartialEq)] 44 | pub struct TestServerConfig { 45 | /// Which transport mode to use to process requests. 46 | /// For setting if the server should use mocked http (which uses [`tower::util::Oneshot`](tower::util::Oneshot)), 47 | /// or if it should run on a named or random IP address. 48 | /// 49 | /// The default is to use mocking, apart from services built using [`axum::extract::connect_info::IntoMakeServiceWithConnectInfo`](axum::extract::connect_info::IntoMakeServiceWithConnectInfo) 50 | /// (this is because it needs a real TCP stream). 
51 | pub transport: Option, 52 | 53 | /// Set for the server to save cookies that are returned, 54 | /// for use in future requests. 55 | /// 56 | /// This is useful for automatically saving session cookies (and similar) 57 | /// like a browser would do. 58 | /// 59 | /// **Defaults** to false (being turned off). 60 | pub save_cookies: bool, 61 | 62 | /// Asserts that requests made to the test server, 63 | /// will by default, 64 | /// return a status code in the 2xx range. 65 | /// 66 | /// This can be overridden on a per request basis using 67 | /// [`TestRequest::expect_failure()`](crate::TestRequest::expect_failure()). 68 | /// 69 | /// This is useful when making multiple requests at a start of test 70 | /// which you presume should always work. 71 | /// 72 | /// **Defaults** to false (being turned off). 73 | pub expect_success_by_default: bool, 74 | 75 | /// If you make a request with a 'http://' schema, 76 | /// then it will ignore the Test Server's address. 77 | /// 78 | /// For example if the test server is running at `http://localhost:1234`, 79 | /// and you make a request to `http://google.com`. 80 | /// Then the request will go to `http://google.com`. 81 | /// Ignoring the `localhost:1234` part. 82 | /// 83 | /// Turning this setting on will change this behaviour. 84 | /// 85 | /// After turning this on, the same request will go to 86 | /// `http://localhost:1234/http://google.com`. 87 | /// 88 | /// **Defaults** to false (being turned off). 89 | pub restrict_requests_with_http_schema: bool, 90 | 91 | /// Set the default content type for all requests created by the `TestServer`. 92 | /// 93 | /// This overrides the default 'best efforts' approach of requests. 94 | pub default_content_type: Option, 95 | 96 | /// Set the default scheme to use for all requests created by the `TestServer`. 97 | /// 98 | /// This overrides the default 'http'. 
99 | pub default_scheme: Option, 100 | } 101 | 102 | impl TestServerConfig { 103 | /// Creates a default `TestServerConfig`. 104 | pub fn new() -> Self { 105 | Default::default() 106 | } 107 | 108 | /// This is shorthand for calling [`crate::TestServer::new_with_config`], 109 | /// and passing this config. 110 | /// 111 | /// ```rust 112 | /// # async fn test() -> Result<(), Box> { 113 | /// # 114 | /// use axum::Router; 115 | /// use axum_test::TestServer; 116 | /// use axum_test::TestServerConfig; 117 | /// 118 | /// let app = Router::new(); 119 | /// let config = TestServerConfig { 120 | /// save_cookies: true, 121 | /// default_content_type: Some("application/json".to_string()), 122 | /// ..Default::default() 123 | /// }; 124 | /// let server = TestServer::new_with_config(app, config)?; 125 | /// # 126 | /// # Ok(()) 127 | /// # } 128 | /// ``` 129 | pub fn build(self, app: A) -> Result 130 | where 131 | A: IntoTransportLayer, 132 | { 133 | TestServer::new_with_config(app, self) 134 | } 135 | } 136 | 137 | impl Default for TestServerConfig { 138 | fn default() -> Self { 139 | Self { 140 | transport: None, 141 | save_cookies: false, 142 | expect_success_by_default: false, 143 | restrict_requests_with_http_schema: false, 144 | default_content_type: None, 145 | default_scheme: None, 146 | } 147 | } 148 | } 149 | 150 | impl From for TestServerConfig { 151 | fn from(builder: TestServerBuilder) -> Self { 152 | builder.into_config() 153 | } 154 | } 155 | 156 | #[cfg(test)] 157 | mod test_scheme { 158 | use axum::Router; 159 | use axum::extract::Request; 160 | use axum::routing::get; 161 | 162 | use crate::TestServer; 163 | use crate::TestServerConfig; 164 | 165 | async fn route_get_scheme(request: Request) -> String { 166 | request.uri().scheme_str().unwrap().to_string() 167 | } 168 | 169 | #[tokio::test] 170 | async fn it_should_set_scheme_when_present_in_config() { 171 | let router = Router::new().route("/scheme", get(route_get_scheme)); 172 | 173 | let config = 
TestServerConfig { 174 | default_scheme: Some("https".to_string()), 175 | ..Default::default() 176 | }; 177 | let server = TestServer::new_with_config(router, config).unwrap(); 178 | 179 | server.get("/scheme").await.assert_text("https"); 180 | } 181 | } 182 | -------------------------------------------------------------------------------- /src/multipart/part.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Context; 2 | use bytes::Bytes; 3 | use http::HeaderName; 4 | use http::HeaderValue; 5 | use mime::Mime; 6 | use std::fmt::Debug; 7 | use std::fmt::Display; 8 | 9 | /// 10 | /// For creating a section of a MultipartForm. 11 | /// 12 | /// Use [`Part::text()`](crate::multipart::Part::text()) and [`Part::bytes()`](crate::multipart::Part::bytes()) for creating new instances. 13 | /// Then attach them to a `MultipartForm` using [`MultipartForm::add_part()`](crate::multipart::MultipartForm::add_part()). 14 | /// 15 | #[derive(Debug, Clone)] 16 | pub struct Part { 17 | pub(crate) bytes: Bytes, 18 | pub(crate) file_name: Option, 19 | pub(crate) mime_type: Mime, 20 | pub(crate) headers: Vec<(HeaderName, HeaderValue)>, 21 | } 22 | 23 | impl Part { 24 | /// Creates a new part of a multipart form, that will send text. 25 | /// 26 | /// The default mime type for this part will be `text/plain`. 27 | pub fn text(text: T) -> Self 28 | where 29 | T: Display, 30 | { 31 | let bytes = text.to_string().into_bytes().into(); 32 | 33 | Self::new(bytes, mime::TEXT_PLAIN) 34 | } 35 | 36 | /// Creates a new part of a multipart form, that will upload bytes.
37 | /// 38 | /// The default mime type for this part will be `application/octet-stream`. 39 | pub fn bytes(bytes: B) -> Self 40 | where 41 | B: Into, 42 | { 43 | Self::new(bytes.into(), mime::APPLICATION_OCTET_STREAM) 44 | } 45 | 46 | fn new(bytes: Bytes, mime_type: Mime) -> Self { 47 | Self { 48 | bytes, 49 | file_name: None, 50 | mime_type, 51 | headers: Default::default(), 52 | } 53 | } 54 | 55 | /// Sets the file name for this part of a multipart form. 56 | /// 57 | /// By default there is no filename. This will set one. 58 | pub fn file_name(mut self, file_name: T) -> Self 59 | where 60 | T: Display, 61 | { 62 | self.file_name = Some(file_name.to_string()); 63 | self 64 | } 65 | 66 | /// Sets the mime type for this part of a multipart form. 67 | /// 68 | /// The default mime type is `text/plain` or `application/octet-stream`, 69 | /// depending on how this instance was created. 70 | /// This function will replace that, and panics if `mime_type` cannot be parsed as a Mime type. 71 | pub fn mime_type(mut self, mime_type: M) -> Self 72 | where 73 | M: AsRef, 74 | { 75 | let raw_mime_type = mime_type.as_ref(); 76 | let parsed_mime_type = raw_mime_type 77 | .parse() 78 | .with_context(|| format!("Failed to parse '{raw_mime_type}' as a Mime type")) 79 | .unwrap(); 80 | 81 | self.mime_type = parsed_mime_type; 82 | 83 | self 84 | } 85 | 86 | /// Adds a header to be sent with this Part of a multipart form.
87 | /// 88 | /// ```rust 89 | /// # async fn test() -> Result<(), Box> { 90 | /// # 91 | /// use axum::Router; 92 | /// use axum_test::TestServer; 93 | /// use axum_test::multipart::MultipartForm; 94 | /// use axum_test::multipart::Part; 95 | /// 96 | /// let app = Router::new(); 97 | /// let server = TestServer::new(app)?; 98 | /// 99 | /// let readme_bytes = include_bytes!("../../README.md"); 100 | /// let readme_part = Part::bytes(readme_bytes.as_slice()) 101 | /// .file_name(&"README.md") 102 | /// // Add a header to the Part 103 | /// .add_header("x-text-category", "readme"); 104 | /// 105 | /// let multipart_form = MultipartForm::new() 106 | /// .add_part("file", readme_part); 107 | /// 108 | /// let response = server.post(&"/my-form") 109 | /// .multipart(multipart_form) 110 | /// .await; 111 | /// # 112 | /// # Ok(()) } 113 | /// ``` 114 | /// Panics if `name` or `value` cannot be converted into a `HeaderName` or `HeaderValue`. 115 | pub fn add_header(mut self, name: N, value: V) -> Self 116 | where 117 | N: TryInto, 118 | N::Error: Debug, 119 | V: TryInto, 120 | V::Error: Debug, 121 | { 122 | let header_name: HeaderName = name 123 | .try_into() 124 | .expect("Failed to convert header name to HeaderName"); 125 | let header_value: HeaderValue = value 126 | .try_into() 127 | .expect("Failed to convert header value to HeaderValue"); 128 | 129 | self.headers.push((header_name, header_value)); 130 | self 131 | } 132 | } 133 | 134 | #[cfg(test)] 135 | mod test_text { 136 | use super::*; 137 | 138 | #[test] 139 | fn it_should_contain_text_given() { 140 | let part = Part::text("some_text"); 141 | 142 | let output = String::from_utf8_lossy(&part.bytes); 143 | assert_eq!(output, "some_text"); 144 | } 145 | 146 | #[test] 147 | fn it_should_use_mime_type_text() { 148 | let part = Part::text("some_text"); 149 | assert_eq!(part.mime_type, mime::TEXT_PLAIN); 150 | } 151 | } 152 | 153 | #[cfg(test)] 154 | mod test_bytes { 155 | use super::*; 156 | 157 | #[test] 158 | fn it_should_contain_bytes_given() { 159 | let bytes = "some_text".as_bytes(); 160 | let
part = Part::bytes(bytes); 161 | 162 | let output = String::from_utf8_lossy(&part.bytes); 163 | assert_eq!(output, "some_text"); 164 | } 165 | 166 | #[test] 167 | fn it_should_use_mime_type_octet_stream() { 168 | let bytes = "some_text".as_bytes(); 169 | let part = Part::bytes(bytes); 170 | 171 | assert_eq!(part.mime_type, mime::APPLICATION_OCTET_STREAM); 172 | } 173 | } 174 | 175 | #[cfg(test)] 176 | mod test_file_name { 177 | use super::*; 178 | 179 | #[test] 180 | fn it_should_use_file_name_given() { 181 | let mut part = Part::text("some_text"); 182 | 183 | assert_eq!(part.file_name, None); 184 | part = part.file_name("my-text.txt"); 185 | assert_eq!(part.file_name, Some("my-text.txt".to_string())); 186 | } 187 | } 188 | 189 | #[cfg(test)] 190 | mod test_mime_type { 191 | use super::*; 192 | 193 | #[test] 194 | fn it_should_use_mime_type_set() { 195 | let mut part = Part::text("some_text"); 196 | 197 | assert_eq!(part.mime_type, mime::TEXT_PLAIN); 198 | part = part.mime_type("application/json"); 199 | assert_eq!(part.mime_type, mime::APPLICATION_JSON); 200 | } 201 | 202 | #[test] 203 | #[should_panic(expected = "Failed to parse '🦊' as a Mime type")] 204 | fn it_should_error_if_invalid_mime_type() { 205 | let part = Part::text("some_text"); 206 | part.mime_type("🦊"); 207 | 208 | unreachable!("mime_type should have panicked on an invalid mime type"); 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/into_make_service.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use axum::extract::Request as AxumRequest; 3 | use axum::response::Response as AxumResponse; 4 | use axum::routing::IntoMakeService; 5 | use std::convert::Infallible; 6 | use tower::Service; 7 | use url::Url; 8 | 9 | use crate::internals::HttpTransportLayer; 10 | use crate::internals::MockTransportLayer; 11 | use crate::transport_layer::IntoTransportLayer; 12 | use crate::transport_layer::TransportLayer; 13 | use
crate::transport_layer::TransportLayerBuilder; 14 | use crate::util::spawn_serve; 15 | 16 | impl IntoTransportLayer for IntoMakeService 17 | where 18 | S: Service 19 | + Clone 20 | + Send 21 | + Sync 22 | + 'static, 23 | S::Future: Send, 24 | { 25 | fn into_http_transport_layer( 26 | self, 27 | builder: TransportLayerBuilder, 28 | ) -> Result> { 29 | let (socket_addr, tcp_listener, maybe_reserved_port) = 30 | builder.tcp_listener_with_reserved_port()?; 31 | 32 | let serve_handle = spawn_serve(tcp_listener, self); 33 | let server_address = format!("http://{socket_addr}"); 34 | let server_url: Url = server_address.parse()?; 35 | 36 | Ok(Box::new(HttpTransportLayer::new( 37 | serve_handle, 38 | maybe_reserved_port, 39 | server_url, 40 | ))) 41 | } 42 | 43 | fn into_mock_transport_layer(self) -> Result> { 44 | let transport_layer = MockTransportLayer::new(self); 45 | Ok(Box::new(transport_layer)) 46 | } 47 | } 48 | 49 | #[cfg(test)] 50 | mod test_into_http_transport_layer_for_into_make_service { 51 | use crate::TestServer; 52 | use axum::Router; 53 | use axum::ServiceExt; 54 | use axum::extract::Request; 55 | use axum::extract::State; 56 | use axum::routing::get; 57 | use tower::Layer; 58 | use tower_http::normalize_path::NormalizePathLayer; 59 | 60 | async fn get_ping() -> &'static str { 61 | "pong!" 62 | } 63 | 64 | async fn get_state(State(count): State) -> String { 65 | format!("count is {count}") 66 | } 67 | 68 | #[tokio::test] 69 | async fn it_should_create_and_test_with_make_into_service() { 70 | // Build an application with a route. 71 | let app = Router::new() 72 | .route("/ping", get(get_ping)) 73 | .into_make_service(); 74 | 75 | // Run the server. 76 | let server = TestServer::builder() 77 | .http_transport() 78 | .build(app) 79 | .expect("Should create test server"); 80 | 81 | // Get the request.
82 | server.get(&"/ping").await.assert_text(&"pong!"); 83 | } 84 | 85 | #[tokio::test] 86 | async fn it_should_create_and_test_with_make_into_service_with_state() { 87 | // Build an application with a route. 88 | let app = Router::new() 89 | .route("/count", get(get_state)) 90 | .with_state(123) 91 | .into_make_service(); 92 | 93 | // Run the server. 94 | let server = TestServer::builder() 95 | .http_transport() 96 | .build(app) 97 | .expect("Should create test server"); 98 | 99 | // Get the request. 100 | server.get(&"/count").await.assert_text(&"count is 123"); 101 | } 102 | 103 | #[tokio::test] 104 | async fn it_should_create_and_run_with_router_wrapped_service() { 105 | // Build an application with a route. 106 | let router = Router::new() 107 | .route("/count", get(get_state)) 108 | .with_state(123); 109 | let normalized_router = NormalizePathLayer::trim_trailing_slash().layer(router); 110 | let app = ServiceExt::::into_make_service(normalized_router); 111 | 112 | // Run the server. 113 | let server = TestServer::builder() 114 | .http_transport() 115 | .build(app) 116 | .expect("Should create test server"); 117 | 118 | // Get the request. 119 | server.get(&"/count").await.assert_text(&"count is 123"); 120 | } 121 | } 122 | 123 | #[cfg(test)] 124 | mod test_into_mock_transport_layer_for_into_make_service { 125 | use crate::TestServer; 126 | use axum::Router; 127 | use axum::ServiceExt; 128 | use axum::extract::Request; 129 | use axum::extract::State; 130 | use axum::routing::get; 131 | use tower::Layer; 132 | use tower_http::normalize_path::NormalizePathLayer; 133 | 134 | async fn get_ping() -> &'static str { 135 | "pong!" 136 | } 137 | 138 | async fn get_state(State(count): State) -> String { 139 | format!("count is {count}") 140 | } 141 | 142 | #[tokio::test] 143 | async fn it_should_create_and_test_with_make_into_service() { 144 | // Build an application with a route.
145 | let app = Router::new() 146 | .route("/ping", get(get_ping)) 147 | .into_make_service(); 148 | 149 | // Run the server. 150 | let server = TestServer::builder() 151 | .mock_transport() 152 | .build(app) 153 | .expect("Should create test server"); 154 | 155 | // Get the request. 156 | server.get(&"/ping").await.assert_text(&"pong!"); 157 | } 158 | 159 | #[tokio::test] 160 | async fn it_should_create_and_test_with_make_into_service_with_state() { 161 | // Build an application with a route. 162 | let app = Router::new() 163 | .route("/count", get(get_state)) 164 | .with_state(123) 165 | .into_make_service(); 166 | 167 | // Run the server. 168 | let server = TestServer::builder() 169 | .mock_transport() 170 | .build(app) 171 | .expect("Should create test server"); 172 | 173 | // Get the request. 174 | server.get(&"/count").await.assert_text(&"count is 123"); 175 | } 176 | 177 | #[tokio::test] 178 | async fn it_should_create_and_run_with_router_wrapped_service() { 179 | // Build an application with a route. 180 | let router = Router::new() 181 | .route("/count", get(get_state)) 182 | .with_state(123); 183 | let normalized_router = NormalizePathLayer::trim_trailing_slash().layer(router); 184 | let app = ServiceExt::::into_make_service(normalized_router); 185 | 186 | // Run the server. 187 | let server = TestServer::builder() 188 | .mock_transport() 189 | .build(app) 190 | .expect("Should create test server"); 191 | 192 | // Get the request. 193 | server.get(&"/count").await.assert_text(&"count is 123"); 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | Axum Test 4 |

5 | 6 |

7 | Easy E2E testing for Axum
8 | including REST, WebSockets, and more 9 |

10 | 11 | [![crate](https://img.shields.io/crates/v/axum-test.svg)](https://crates.io/crates/axum-test) 12 | [![docs](https://docs.rs/axum-test/badge.svg)](https://docs.rs/axum-test) 13 | 14 |
15 |
16 | 17 | This runs your application locally, allowing you to query against it with requests. 18 | Decode the responses, and assert what is returned. 19 | 20 | ```rust 21 | use axum::Router; 22 | use axum::routing::get; 23 | 24 | use axum_test::TestServer; 25 | 26 | #[tokio::test] 27 | async fn it_should_ping_pong() { 28 | // Build an application with a route. 29 | let app = Router::new() 30 | .route(&"/ping", get(|| async { "pong!" })); 31 | 32 | // Run the application for testing. 33 | let server = TestServer::new(app).unwrap(); 34 | 35 | // Get the request. 36 | let response = server 37 | .get("/ping") 38 | .await; 39 | 40 | // Assertions. 41 | response.assert_status_ok(); 42 | response.assert_text("pong!"); 43 | } 44 | ``` 45 | 46 | A `TestServer` enables you to run an Axum service with a mocked network, 47 | or on a random port with real network requests. 48 | In both cases, you can run multiple servers, across multiple tests, all in parallel. 49 | 50 | 51 | ## 🎁 Crate Features 52 | 53 | | Feature | On by default | | 54 | |---------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------| 55 | | `all` | _off_ | Turns on all non-deprecated features. | 56 | | `pretty-assertions` | **on** | Uses the [pretty assertions crate](https://crates.io/crates/pretty_assertions) on response `assert_*` methods. | 57 | | `yaml` | _off_ | Enables support for sending, receiving, and asserting, [yaml content](https://yaml.org/). | 58 | | `msgpack` | _off_ | Enables support for sending, receiving, and asserting, [msgpack content](https://msgpack.org/index.html). | 59 | | `shuttle` | _off_ | Enables support for building a `TestServer` from an [`shuttle_axum::AxumService`](https://docs.rs/shuttle-axum/latest/shuttle_axum/struct.AxumService.html), for use with [Shuttle.rs](https://shuttle.rs).
| 60 | `typed-routing` | _off_ | Enables support for using `TypedPath` in requests. See [axum-extra](https://crates.io/crates/axum-extra) for details. | 61 | `ws` | _off_ | Enables WebSocket support. See [TestWebSocket](https://docs.rs/axum-test/latest/axum_test/struct.TestWebSocket.html) for details. | 62 | `reqwest` | _off_ | Enables the `TestServer` being able to create [Reqwest](https://docs.rs/reqwest) requests for querying. | 63 | 64 | 65 | ### Deprecated 66 | 67 | | Feature | On by default | | 68 | |---------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------| 69 | | `old-json-diff` | _off_ | Switches back to the old Json diff behaviour before Axum Test 18. If you find yourself needing this, then please raise an issue to let me know why. | 70 | 71 | 72 | ## ⚙️ Axum Compatibility 73 | 74 | The current version of Axum Test requires at least Axum v0.8.7. 75 | 76 | Here is a list of compatibility with prior versions: 77 | 78 | | Axum Version | Axum Test Version | 79 | |-----------------|-------------------| 80 | | 0.8.7+ (latest) | 18.3.0 (latest) | 81 | | 0.8.4 | 18.0.0 | 82 | | 0.8.3 | 17.3 | 83 | | 0.8.0 | 17 | 84 | | 0.7.6 to 0.7.9 | 16 | 85 | | 0.7.0 to 0.7.5 | 14, 15 | 86 | | 0.6 | 13.4.1 | 87 | 88 | 89 | ## 📺 Examples 90 | 91 | You can find examples of writing tests in the [/examples folder](/examples/).
92 | These include tests for: 93 | 94 | * [a simple REST Todo application](/examples/example-todo), and [the same using Shuttle](/examples/example-shuttle) 95 | * [a WebSocket ping pong application](/examples/example-websocket-ping-pong) which sends requests up and down 96 | * [a simple WebSocket chat application](/examples/example-websocket-chat) 97 | 98 | 99 | ## 🚀 Request Building Features 100 | 101 | Querying your application on the `TestServer` supports all of the common request building you would expect. 102 | 103 | - Serializing and deserializing Json, Form, Yaml, and others, using Serde 104 | - Assertions on the Json, text, Yaml, etc, that is returned. 105 | - Cookie, query, and header setting and reading 106 | - Status code reading and assertions 107 | 108 | 109 | ### Powerful Json assertions 110 | 111 | The ability to assert only the _shape_ of the Json returned: 112 | 113 | ```rust 114 | use axum_test::TestServer; 115 | use axum_test::expect_json; 116 | use std::time::Duration; 117 | 118 | // Your application 119 | let app = Router::new() 120 | .route(&"/user/alan", get(|| async { 121 | // ... 
122 | })); 123 | 124 | let server = TestServer::new(app)?; 125 | server.get(&"/user/alan") 126 | .await 127 | .assert_json(&json!({ 128 | "name": "Alan", 129 | 130 | // expect a valid UUID 131 | "id": expect_json::uuid(), 132 | 133 | // expect an adult age 134 | "age": expect_json::integer() 135 | .in_range(18..=120), 136 | 137 | // expect user to be created within the last minute 138 | "created_at": expect_json::iso_date_time() 139 | .within_past(Duration::from_secs(60)) 140 | .utc() 141 | })); 142 | ``` 143 | 144 | Docs: 145 | - [axum_test::TestResponse::assert_json](https://docs.rs/axum-test/latest/axum_test/struct.TestResponse.html#method.assert_json) 146 | - [axum_test::expect_json](https://docs.rs/axum-test/latest/axum_test/expect_json/index.html) 147 | 148 | 149 | ### Also includes 150 | 151 | - WebSockets testing support 152 | - Saving returned cookies for use on future requests 153 | - Setting headers, query, and cookies, globally for all requests or on a per-request basis 154 | - Can run requests using a real web server, or with mocked HTTP 155 | - Automatic status assertions for expecting requests to succeed (to help catch bugs in tests sooner) 156 | - Prettified assertion output 157 | - Typed Routing from Axum Extra 158 | - Reqwest integration 159 | 160 | 161 | ## ❤️ Contributions 162 | 163 | A big thank you to everyone who has helped! 164 | 165 |
166 | 167 | 168 | 169 | Made with [contrib.rocks](https://contrib.rocks) 170 | -------------------------------------------------------------------------------- /examples/example-websocket-chat/main.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! This is an example chat application using WebSockets for communication. 3 | //! 4 | //! At the bottom of this file is a series of tests for using websockets. 5 | //! 6 | //! ```bash 7 | //! # To run its tests: 8 | //! cargo test --example=example-websocket-chat --features ws 9 | //! ``` 10 | //! 11 | 12 | use anyhow::Result; 13 | use axum::Router; 14 | use axum::extract::Path; 15 | use axum::extract::State; 16 | use axum::extract::WebSocketUpgrade; 17 | use axum::extract::ws::Message; 18 | use axum::extract::ws::WebSocket; 19 | use axum::response::Response; 20 | use axum::routing::get; 21 | use axum::serve::serve; 22 | use futures_util::SinkExt; 23 | use futures_util::StreamExt; 24 | use serde::Deserialize; 25 | use serde::Serialize; 26 | use std::collections::HashMap; 27 | use std::net::IpAddr; 28 | use std::net::Ipv4Addr; 29 | use std::net::SocketAddr; 30 | use std::sync::Arc; 31 | use std::time::Duration; 32 | use tokio::net::TcpListener; 33 | use tokio::sync::RwLock; 34 | 35 | #[cfg(test)] 36 | use axum_test::TestServer; 37 | 38 | const PORT: u16 = 8080; 39 | 40 | #[tokio::main] 41 | async fn main() { 42 | let result: Result<()> = { 43 | let app = new_app(); 44 | 45 | // Start! 46 | let ip_address = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); 47 | let address = SocketAddr::new(ip_address, PORT); 48 | let listener = TcpListener::bind(address).await.unwrap(); 49 | serve(listener, app.into_make_service()).await.unwrap(); 50 | 51 | Ok(()) 52 | }; 53 | 54 | if let Err(err) = &result { 55 | eprintln!("{err}"); 56 | } 57 | 58 | } 59 | 60 | type SharedAppState = Arc>; 61 | 62 | /// This is my poor man's chat system.
63 | /// 64 | /// It holds a map of usernames to the messages waiting to be delivered to them. 65 | #[derive(Debug)] 66 | pub struct AppState { 67 | user_messages: HashMap>, 68 | } 69 | 70 | #[derive(Deserialize, Serialize, Debug, PartialEq)] 71 | pub struct ChatSendMessage { 72 | pub to: String, 73 | pub message: String, 74 | } 75 | 76 | #[derive(Deserialize, Serialize, Debug, PartialEq)] 77 | pub struct ChatReceivedMessage { 78 | pub from: String, 79 | pub message: String, 80 | } 81 | 82 | pub async fn route_get_websocket_chat( 83 | State(state): State, 84 | Path(username): Path, 85 | ws: WebSocketUpgrade, 86 | ) -> Response { 87 | ws.on_upgrade(move |socket| handle_chat(socket, username, state)) 88 | } 89 | 90 | async fn handle_chat(socket: WebSocket, username: String, state: SharedAppState) { 91 | let (mut sender, mut receiver) = socket.split(); 92 | 93 | // Spawn a task that forwards any messages queued for this user to the client, oldest first. 94 | let send_state = state.clone(); 95 | let send_username = username.clone(); 96 | let mut send_task = tokio::spawn(async move { 97 | loop { 98 | let pending = { 99 | let mut state_locked = send_state.write().await; 100 | // Take every queued message (in arrival order); the lock is released at the end of this block, before any `.await` below. 101 | state_locked.user_messages.get_mut(&send_username).map(std::mem::take).unwrap_or_default() 102 | }; 103 | for message in pending { 104 | let json_text = serde_json::to_string(&message) 105 | .expect("Failed to build JSON message for sending"); 106 | 107 | sender 108 | .send(Message::Text(json_text.into())) 109 | .await 110 | .expect("Failed to send message to socket"); 111 | } 112 | 113 | ::tokio::time::sleep(Duration::from_millis(10)).await; 114 | } 115 | }); 116 | 117 | // This second task receives messages from the client and queues them for their recipient. 118 | let mut recv_task = tokio::spawn(async move { 119 | while let Some(Ok(message)) = receiver.next().await { 120 | let raw_text = message 121 | .into_text() 122 | .expect("Failed to read text from incoming message"); 123 | let decoded =
serde_json::from_str::(&raw_text) 124 | .expect("Failed to decode incoming JSON message"); 125 | 126 | let mut state_locked = state.write().await; 127 | let maybe_messages = state_locked.user_messages.entry(decoded.to); 128 | maybe_messages.or_default().push(ChatReceivedMessage { 129 | from: username.clone(), 130 | message: decoded.message, 131 | }); 132 | } 133 | }); 134 | 135 | // If any one of the tasks exits, abort the other. 136 | tokio::select! { 137 | rv_a = (&mut send_task) => { 138 | match rv_a { 139 | Ok(_) => println!("Messages sent"), 140 | Err(a) => println!("Error sending messages {a:?}") 141 | } 142 | recv_task.abort(); 143 | }, 144 | rv_b = (&mut recv_task) => { 145 | match rv_b { 146 | Ok(_) => println!("Received messages"), 147 | Err(b) => println!("Error receiving messages {b:?}") 148 | } 149 | send_task.abort(); 150 | } 151 | } 152 | } 153 | 154 | pub(crate) fn new_app() -> Router { 155 | let state = AppState { 156 | user_messages: HashMap::new(), 157 | }; 158 | let shared_state = Arc::new(RwLock::new(state)); 159 | 160 | Router::new() 161 | .route(&"/ws-chat/{name}", get(route_get_websocket_chat)) 162 | .with_state(shared_state) 163 | } 164 | 165 | #[cfg(test)] 166 | fn new_test_app() -> TestServer { 167 | let app = new_app(); 168 | TestServer::builder() 169 | .http_transport() // Important! It must be an HTTP transport here.
170 | .build(app) 171 | .unwrap() 172 | } 173 | 174 | #[cfg(test)] 175 | mod test_websockets_chat { 176 | use super::*; 177 | 178 | #[tokio::test] 179 | async fn it_should_start_a_websocket_connection() { 180 | let server = new_test_app(); 181 | 182 | let response = server.get_websocket(&"/ws-chat/john").await; 183 | 184 | response.assert_status_switching_protocols(); 185 | } 186 | 187 | #[tokio::test] 188 | async fn it_should_send_messages_back_and_forth() { 189 | let server = new_test_app(); 190 | 191 | let mut alice_chat = server 192 | .get_websocket(&"/ws-chat/alice") 193 | .await 194 | .into_websocket() 195 | .await; 196 | 197 | let mut bob_chat = server 198 | .get_websocket(&"/ws-chat/bob") 199 | .await 200 | .into_websocket() 201 | .await; 202 | 203 | bob_chat 204 | .send_json(&ChatSendMessage { 205 | to: "alice".to_string(), 206 | message: "How are you Alice?".to_string(), 207 | }) 208 | .await; 209 | 210 | alice_chat 211 | .assert_receive_json(&ChatReceivedMessage { 212 | from: "bob".to_string(), 213 | message: "How are you Alice?".to_string(), 214 | }) 215 | .await; 216 | alice_chat 217 | .send_json(&ChatSendMessage { 218 | to: "bob".to_string(), 219 | message: "I am good".to_string(), 220 | }) 221 | .await; 222 | alice_chat 223 | .send_json(&ChatSendMessage { 224 | to: "bob".to_string(), 225 | message: "How are you?".to_string(), 226 | }) 227 | .await; 228 | 229 | bob_chat 230 | .assert_receive_json(&ChatReceivedMessage { 231 | from: "alice".to_string(), 232 | message: "I am good".to_string(), 233 | }) 234 | .await; 235 | bob_chat 236 | .assert_receive_json(&ChatReceivedMessage { 237 | from: "alice".to_string(), 238 | message: "How are you?".to_string(), 239 | }) 240 | .await; 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /src/test_server_builder.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use std::net::IpAddr; 3 | 4 | use
crate::TestServer; 5 | use crate::TestServerConfig; 6 | use crate::Transport; 7 | use crate::transport_layer::IntoTransportLayer; 8 | 9 | /// A builder for [`crate::TestServer`]. Inside is a [`crate::TestServerConfig`], 10 | /// configured by each method, and then turned into a server by [`crate::TestServerBuilder::build`]. 11 | /// 12 | /// The recommended way to make instances is to call [`crate::TestServer::builder`]. 13 | /// 14 | /// # Creating a [`crate::TestServer`] 15 | /// 16 | /// ```rust 17 | /// # async fn test() -> Result<(), Box> { 18 | /// # 19 | /// use axum::Router; 20 | /// use axum_test::TestServerBuilder; 21 | /// 22 | /// let my_app = Router::new(); 23 | /// let server = TestServerBuilder::new() 24 | /// .save_cookies() 25 | /// .default_content_type(&"application/json") 26 | /// .build(my_app)?; 27 | /// # 28 | /// # Ok(()) 29 | /// # } 30 | /// ``` 31 | /// 32 | /// # Creating a [`crate::TestServerConfig`] 33 | /// 34 | /// ```rust 35 | /// # async fn test() -> Result<(), Box> { 36 | /// # 37 | /// use axum::Router; 38 | /// use axum_test::TestServer; 39 | /// use axum_test::TestServerBuilder; 40 | /// 41 | /// let my_app = Router::new(); 42 | /// let config = TestServerBuilder::new() 43 | /// .save_cookies() 44 | /// .default_content_type(&"application/json") 45 | /// .into_config(); 46 | /// 47 | /// // Build the Test Server 48 | /// let server = TestServer::new_with_config(my_app, config)?; 49 | /// # 50 | /// # Ok(()) 51 | /// # } 52 | /// ``` 53 | /// 54 | /// These can be passed to [`crate::TestServer::new_with_config`]. 55 | /// 56 | #[derive(Debug, Clone)] 57 | pub struct TestServerBuilder { 58 | config: TestServerConfig, 59 | } 60 | 61 | impl TestServerBuilder { 62 | /// Creates a default `TestServerBuilder`.
63 | pub fn new() -> Self { 64 | Default::default() 65 | } 66 | /// Creates a builder starting from the settings in the given [`crate::TestServerConfig`]. 67 | pub fn from_config(config: TestServerConfig) -> Self { 68 | Self { config } 69 | } 70 | /// Sets the transport to [`Transport::HttpRandomPort`]. 71 | pub fn http_transport(self) -> Self { 72 | self.transport(Transport::HttpRandomPort) 73 | } 74 | /// Sets the transport to [`Transport::HttpIpPort`], using the optional `ip` and `port` given. 75 | pub fn http_transport_with_ip_port(self, ip: Option, port: Option) -> Self { 76 | self.transport(Transport::HttpIpPort { ip, port }) 77 | } 78 | /// Sets the transport to [`Transport::MockHttp`]. 79 | pub fn mock_transport(self) -> Self { 80 | self.transport(Transport::MockHttp) 81 | } 82 | /// Sets the [`Transport`] that the built server will use. 83 | pub fn transport(mut self, transport: Transport) -> Self { 84 | self.config.transport = Some(transport); 85 | self 86 | } 87 | /// Turns on saving cookies (sets `save_cookies` to `true`). 88 | pub fn save_cookies(mut self) -> Self { 89 | self.config.save_cookies = true; 90 | self 91 | } 92 | /// Turns off saving cookies (sets `save_cookies` to `false`). 93 | pub fn do_not_save_cookies(mut self) -> Self { 94 | self.config.save_cookies = false; 95 | self 96 | } 97 | /// Sets `default_content_type` to the content type given. 98 | pub fn default_content_type(mut self, content_type: &str) -> Self { 99 | self.config.default_content_type = Some(content_type.to_string()); 100 | self 101 | } 102 | /// Sets `default_scheme` to the scheme given (e.g. `"https"`). 103 | pub fn default_scheme(mut self, scheme: &str) -> Self { 104 | self.config.default_scheme = Some(scheme.to_string()); 105 | self 106 | } 107 | /// Turns on `expect_success_by_default`. 108 | pub fn expect_success_by_default(mut self) -> Self { 109 | self.config.expect_success_by_default = true; 110 | self 111 | } 112 | /// Turns on `restrict_requests_with_http_schema`. 
113 | pub fn restrict_requests_with_http_schema(mut self) -> Self { 114 | self.config.restrict_requests_with_http_schema = true; 115 | self 116 | } 117 | 118 | /// For turning this into a [`crate::TestServerConfig`] object, 119 | /// which can be passed to [`crate::TestServer::new_with_config`].
120 | /// 121 | /// ```rust 122 | /// # async fn test() -> Result<(), Box> { 123 | /// # 124 | /// use axum::Router; 125 | /// use axum_test::TestServer; 126 | /// 127 | /// let my_app = Router::new(); 128 | /// let config = TestServer::builder() 129 | /// .save_cookies() 130 | /// .default_content_type(&"application/json") 131 | /// .into_config(); 132 | /// 133 | /// // Build the Test Server 134 | /// let server = TestServer::new_with_config(my_app, config)?; 135 | /// # 136 | /// # Ok(()) 137 | /// # } 138 | /// ``` 139 | pub fn into_config(self) -> TestServerConfig { 140 | self.config 141 | } 142 | 143 | /// Creates a new [`crate::TestServer`], running the application given, 144 | /// and with all settings from this `TestServerBuilder` applied. 145 | /// 146 | /// ```rust 147 | /// use axum::Router; 148 | /// use axum_test::TestServer; 149 | /// 150 | /// let app = Router::new(); 151 | /// let server = TestServer::builder() 152 | /// .save_cookies() 153 | /// .default_content_type(&"application/json") 154 | /// .build(app); 155 | /// ``` 156 | /// 157 | /// This is equivalent to building [`crate::TestServerConfig`] yourself, 158 | /// and calling [`crate::TestServer::new_with_config`].
159 | pub fn build(self, app: A) -> Result 160 | where 161 | A: IntoTransportLayer, 162 | { 163 | self.into_config().build(app) 164 | } 165 | } 166 | 167 | impl Default for TestServerBuilder { 168 | fn default() -> Self { 169 | Self { 170 | config: TestServerConfig::default(), 171 | } 172 | } 173 | } 174 | 175 | impl From for TestServerBuilder { 176 | fn from(config: TestServerConfig) -> Self { 177 | TestServerBuilder::from_config(config) 178 | } 179 | } 180 | 181 | #[cfg(test)] 182 | mod test_build { 183 | use super::*; 184 | use std::net::Ipv4Addr; 185 | 186 | #[test] 187 | fn it_should_build_default_config_by_default() { 188 | let config = TestServer::builder().into_config(); 189 | let expected = TestServerConfig::default(); 190 | 191 | assert_eq!(config, expected); 192 | } 193 | 194 | #[test] 195 | fn it_should_save_cookies_when_set() { 196 | let config = TestServer::builder().save_cookies().into_config(); 197 | 198 | assert!(config.save_cookies); 199 | } 200 | 201 | #[test] 202 | fn it_should_not_save_cookies_when_set() { 203 | let config = TestServer::builder().save_cookies().do_not_save_cookies().into_config(); 204 | 205 | assert!(!config.save_cookies); 206 | } 207 | 208 | #[test] 209 | fn it_should_mock_transport_when_set() { 210 | let config = TestServer::builder().mock_transport().into_config(); 211 | 212 | assert_eq!(config.transport, Some(Transport::MockHttp)); 213 | } 214 | 215 | #[test] 216 | fn it_should_use_random_http_transport_when_set() { 217 | let config = TestServer::builder().http_transport().into_config(); 218 | 219 | assert_eq!(config.transport, Some(Transport::HttpRandomPort)); 220 | } 221 | 222 | #[test] 223 | fn it_should_use_http_transport_with_ip_port_when_set() { 224 | let config = TestServer::builder() 225 | .http_transport_with_ip_port(Some(IpAddr::V4(Ipv4Addr::new(123, 4, 5, 6))), Some(987)) 226 | .into_config(); 227 | 228 | assert_eq!( 229 | config.transport, 230 | Some(Transport::HttpIpPort { 231 | ip:
Some(IpAddr::V4(Ipv4Addr::new(123, 4, 5, 6))), 232 | port: Some(987), 233 | }) 234 | ); 235 | } 236 | 237 | #[test] 238 | fn it_should_set_default_content_type_when_set() { 239 | let config = TestServer::builder() 240 | .default_content_type("text/csv") 241 | .into_config(); 242 | 243 | assert_eq!(config.default_content_type, Some("text/csv".to_string())); 244 | } 245 | 246 | #[test] 247 | fn it_should_set_default_scheme_when_set() { 248 | let config = TestServer::builder().default_scheme("ftps").into_config(); 249 | 250 | assert_eq!(config.default_scheme, Some("ftps".to_string())); 251 | } 252 | 253 | #[test] 254 | fn it_should_set_expect_success_by_default_when_set() { 255 | let config = TestServer::builder() 256 | .expect_success_by_default() 257 | .into_config(); 258 | 259 | assert!(config.expect_success_by_default); 260 | } 261 | 262 | #[test] 263 | fn it_should_set_restrict_requests_with_http_schema_when_set() { 264 | let config = TestServer::builder() 265 | .restrict_requests_with_http_schema() 266 | .into_config(); 267 | 268 | assert!(config.restrict_requests_with_http_schema); 269 | } 270 | } 271 | -------------------------------------------------------------------------------- /src/internals/debug_response_body.rs: -------------------------------------------------------------------------------- 1 | use crate::TestResponse; 2 | use bytesize::ByteSize; 3 | use std::fmt::Display; 4 | use std::fmt::Formatter; 5 | use std::fmt::Result as FmtResult; 6 | 7 | /// An arbitrary limit to avoid printing gigabytes to the terminal.
8 | const MAX_TEXT_PRINT_LEN: usize = 10_000; 9 | 10 | #[derive(Debug)] 11 | pub struct DebugResponseBody<'a>(pub &'a TestResponse); 12 | 13 | impl Display for DebugResponseBody<'_> { 14 | fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { 15 | match self.0.maybe_content_type() { 16 | Some(content_type) => { 17 | match content_type.as_str() { 18 | // Json 19 | "application/json" | "text/json" => write_json(f, self.0), 20 | 21 | // Msgpack 22 | "application/msgpack" => write!(f, ""), 23 | 24 | // Yaml 25 | #[cfg(feature = "yaml")] 26 | "application/yaml" | "application/x-yaml" | "text/yaml" => { 27 | write_yaml(f, self.0) 28 | } 29 | 30 | #[cfg(not(feature = "yaml"))] 31 | "application/yaml" | "application/x-yaml" | "text/yaml" => { 32 | write_text(f, &self.0.text()) 33 | } 34 | 35 | // Text Content 36 | s if s.starts_with("text/") => write_text(f, &self.0.text()), 37 | 38 | // Byte Streams 39 | "application/octet-stream" => { 40 | let len = self.0.as_bytes().len(); 41 | write!(f, "", ByteSize(len as u64)) 42 | } 43 | 44 | // Unknown content type 45 | _ => { 46 | let len = self.0.as_bytes().len(); 47 | write!( 48 | f, 49 | "", 50 | ByteSize(len as u64) 51 | ) 52 | } 53 | } 54 | } 55 | 56 | // We just default to text 57 | _ => write_text(f, &self.0.text()), 58 | } 59 | } 60 | } 61 | 62 | fn write_text(f: &mut Formatter<'_>, text: &str) -> FmtResult { 63 | // Byte offset of the first char past the limit; `None` means the whole text fits. 64 | // Counting chars (not bytes) keeps multi-byte text that fits from being marked as cut off, 65 | // and avoids appending `...` when the text is exactly `MAX_TEXT_PRINT_LEN` chars long. 66 | let maybe_cut_at = text.char_indices().nth(MAX_TEXT_PRINT_LEN).map(|(index, _)| index); 67 | 68 | match maybe_cut_at { 69 | None => write!(f, "'{text}'"), 70 | Some(cut_at) => { 71 | // `cut_at` comes from `char_indices`, so slicing here never splits a UTF-8 sequence. 72 | let text_start = &text[..cut_at]; 73 | write!(f, "'{text_start}...'") 74 | } 75 | } 76 | } 77 | 78 | 79 | fn write_json(f: &mut Formatter<'_>, response: &TestResponse) -> FmtResult { 80 | let bytes = response.as_bytes(); 81 | let result = serde_json::from_slice::(bytes); 82 | 83 | match result { 84 | Err(_) => { 85 | write!( 86 | f, 87 | "!!!
YOUR JSON IS MALFORMED !!!\nBody: '{}'", 88 | response.text() 89 | ) 90 | } 91 | Ok(body) => { 92 | let pretty_raw = serde_json::to_string_pretty(&body) 93 | .expect("Failed to reserialise serde_json::Value of request body"); 94 | write!(f, "{pretty_raw}") 95 | } 96 | } 97 | } 98 | 99 | #[cfg(feature = "yaml")] 100 | fn write_yaml(f: &mut Formatter<'_>, response: &TestResponse) -> FmtResult { 101 | let response_bytes = response.as_bytes(); 102 | let result = serde_yaml::from_slice::(response_bytes); 103 | 104 | match result { 105 | Err(_) => { 106 | write!( 107 | f, 108 | "!!! YOUR YAML IS MALFORMED !!!\nBody: '{}'", 109 | response.text() 110 | ) 111 | } 112 | Ok(body) => { 113 | let pretty_raw = serde_yaml::to_string(&body) 114 | .expect("Failed to reserialise serde_yaml::Value of request body"); 115 | write!(f, "{pretty_raw}") 116 | } 117 | } 118 | } 119 | 120 | #[cfg(test)] 121 | mod test_fmt { 122 | use super::*; 123 | use crate::TestServer; 124 | use axum::Json; 125 | use axum::Router; 126 | use axum::body::Body; 127 | use axum::response::IntoResponse; 128 | use axum::response::Response; 129 | use axum::routing::get; 130 | use http::HeaderValue; 131 | use http::header; 132 | use pretty_assertions::assert_eq; 133 | use serde::Deserialize; 134 | use serde::Serialize; 135 | 136 | #[derive(Serialize, Deserialize, PartialEq, Debug)] 137 | struct ExampleResponse { 138 | name: String, 139 | age: u32, 140 | } 141 | 142 | #[tokio::test] 143 | async fn it_should_display_text_response_as_text() { 144 | let router = Router::new().route("/text", get(|| async { "Blah blah" })); 145 | let response = TestServer::new(router).unwrap().get("/text").await; 146 | 147 | let debug_body = DebugResponseBody(&response); 148 | let output = format!("{debug_body}"); 149 | 150 | assert_eq!(output, "'Blah blah'"); 151 | } 152 | 153 | #[tokio::test] 154 | async fn it_should_cutoff_very_long_text() { 155 | let router = Router::new().route( 156 | "/text", 157 | get(|| async { 158 | let max_len
= MAX_TEXT_PRINT_LEN + 100; 159 | (0..max_len).map(|_| "🦊").collect::() 160 | }), 161 | ); 162 | let response = TestServer::new(router).unwrap().get("/text").await; 163 | 164 | let debug_body = DebugResponseBody(&response); 165 | let output = format!("{debug_body}"); 166 | 167 | let expected_content = (0..MAX_TEXT_PRINT_LEN).map(|_| "🦊").collect::(); 168 | let expected = format!("'{expected_content}...'"); 169 | 170 | assert_eq!(output, expected); 171 | } 172 | 173 | #[tokio::test] 174 | async fn it_should_pretty_print_json() { 175 | let router = Router::new().route( 176 | "/json", 177 | get(|| async { 178 | Json(ExampleResponse { 179 | name: "Joe".to_string(), 180 | age: 20, 181 | }) 182 | }), 183 | ); 184 | let response = TestServer::new(router).unwrap().get("/json").await; 185 | 186 | let debug_body = DebugResponseBody(&response); 187 | let output = format!("{debug_body}"); 188 | let expected = r###"{ 189 | "age": 20, 190 | "name": "Joe" 191 | }"###; 192 | 193 | assert_eq!(output, expected); 194 | } 195 | 196 | #[tokio::test] 197 | async fn it_should_warn_malformed_json() { 198 | let router = Router::new().route( 199 | "/json", 200 | get(|| async { 201 | let body = Body::new(r###"{ "name": "Joe" "###.to_string()); 202 | 203 | Response::builder() 204 | .header( 205 | header::CONTENT_TYPE, 206 | HeaderValue::from_static("application/json"), 207 | ) 208 | .body(body) 209 | .unwrap() 210 | .into_response() 211 | }), 212 | ); 213 | let response = TestServer::new(router).unwrap().get("/json").await; 214 | 215 | let debug_body = DebugResponseBody(&response); 216 | let output = format!("{debug_body}"); 217 | let expected = r###"!!! YOUR JSON IS MALFORMED !!! 
218 | Body: '{ "name": "Joe" '"###; 219 | 220 | assert_eq!(output, expected); 221 | } 222 | 223 | #[cfg(feature = "yaml")] 224 | #[tokio::test] 225 | async fn it_should_pretty_print_yaml() { 226 | use axum_yaml::Yaml; 227 | 228 | let router = Router::new().route( 229 | "/yaml", 230 | get(|| async { 231 | Yaml(ExampleResponse { 232 | name: "Joe".to_string(), 233 | age: 20, 234 | }) 235 | }), 236 | ); 237 | let response = TestServer::new(router).unwrap().get("/yaml").await; 238 | 239 | let debug_body = DebugResponseBody(&response); 240 | let output = format!("{debug_body}"); 241 | let expected = r###"name: Joe 242 | age: 20 243 | "###; 244 | 245 | assert_eq!(output, expected); 246 | } 247 | 248 | #[cfg(feature = "yaml")] 249 | #[tokio::test] 250 | async fn it_should_warn_on_malformed_yaml() { 251 | let router = Router::new().route( 252 | "/yaml", 253 | get(|| async { 254 | let body = Body::new("🦊 🦊 🦊: : :🦊 🦊 🦊".to_string()); 255 | 256 | Response::builder() 257 | .header( 258 | header::CONTENT_TYPE, 259 | HeaderValue::from_static("application/yaml"), 260 | ) 261 | .body(body) 262 | .unwrap() 263 | .into_response() 264 | }), 265 | ); 266 | let response = TestServer::new(router).unwrap().get("/yaml").await; 267 | 268 | let debug_body = DebugResponseBody(&response); 269 | let output = format!("{debug_body}"); 270 | let expected = r###"!!! YOUR YAML IS MALFORMED !!! 271 | Body: '🦊 🦊 🦊: : :🦊 🦊 🦊'"###; 272 | 273 | assert_eq!(output, expected); 274 | } 275 | } 276 | -------------------------------------------------------------------------------- /examples/example-shuttle/main.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! This is an example Todo Application, wrapped with Shuttle. 3 | //! To show some simple tests when using Shuttle + Axum. 4 | //! 5 | //! ```bash 6 | //! # To run it's tests: 7 | //! cargo test --example=example-shuttle --features shuttle 8 | //! ``` 9 | //! 10 | //! The app includes the end points for ... 
11 | //! 12 | //! - POST /login ... this takes an email, and returns a session cookie. 13 | //! - PUT /todo ... once logged in, one can store todos. 14 | //! - GET /todo ... once logged in, you can retrieve all todos you have stored. 15 | //! 16 | //! At the bottom of this file are a series of tests for these endpoints. 17 | //! 18 | 19 | use anyhow::Result; 20 | use anyhow::anyhow; 21 | use axum::Router; 22 | use axum::extract::Json; 23 | use axum::extract::State; 24 | use axum::routing::get; 25 | use axum::routing::post; 26 | use axum::routing::put; 27 | use axum_extra::extract::cookie::Cookie; 28 | use axum_extra::extract::cookie::CookieJar; 29 | use http::StatusCode; 30 | use serde::Deserialize; 31 | use serde::Serialize; 32 | use serde_email::Email; 33 | use std::collections::HashMap; 34 | use std::result::Result as StdResult; 35 | use std::sync::Arc; 36 | use std::sync::RwLock; 37 | 38 | #[cfg(test)] 39 | use axum_test::TestServer; 40 | 41 | /// Main to start Shuttle application 42 | #[shuttle_runtime::main] 43 | async fn main() -> ::shuttle_axum::ShuttleAxum { 44 | new_app() 45 | } 46 | 47 | /// The Shuttle application itself 48 | fn new_app() -> ::shuttle_axum::ShuttleAxum { 49 | let state = AppState { 50 | user_todos: HashMap::new(), 51 | }; 52 | let shared_state = Arc::new(RwLock::new(state)); 53 | 54 | let app = Router::new() 55 | .route(&"/login", post(route_post_user_login)) 56 | .route(&"/todo", get(route_get_user_todos)) 57 | .route(&"/todo", put(route_put_user_todos)) 58 | .with_state(shared_state); 59 | 60 | Ok(app.into()) 61 | } 62 | 63 | /// A TestServer that runs the Shuttle application 64 | #[cfg(test)] 65 | fn new_test_app() -> TestServer { 66 | TestServer::builder() 67 | // Preserve cookies across requests 68 | // for the session cookie to work. 
69 | .save_cookies() 70 | .expect_success_by_default() 71 | .mock_transport() 72 | .build(new_app()) // <- here the application is passed in 73 | .unwrap() 74 | } 75 | 76 | const USER_ID_COOKIE_NAME: &'static str = &"example-shuttle-user-id"; 77 | 78 | type SharedAppState = Arc>; 79 | 80 | // This is my poor man's in-memory DB. 81 | #[derive(Debug)] 82 | pub struct AppState { 83 | user_todos: HashMap>, 84 | } 85 | 86 | #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] 87 | pub struct Todo { 88 | name: String, 89 | content: String, 90 | } 91 | 92 | #[derive(Debug, Clone, Deserialize, Serialize)] 93 | pub struct LoginRequest { 94 | user: Email, 95 | } 96 | 97 | #[derive(Debug, Clone, Deserialize, Serialize)] 98 | pub struct AllTodos { 99 | todos: Vec, 100 | } 101 | 102 | #[derive(Debug, Clone, Deserialize, Serialize)] 103 | pub struct NumTodos { 104 | num: u32, 105 | } 106 | 107 | // Note you should never do something like this in a real application 108 | // for session cookies. It's really bad. Like _seriously_ bad. 109 | // 110 | // This is done like this here to keep the code shorter. That's all.
111 | fn get_user_id_from_cookie(cookies: &CookieJar) -> Result { 112 | cookies 113 | .get(&USER_ID_COOKIE_NAME) 114 | .map(|c| c.value().to_string().parse::().ok()) 115 | .flatten() 116 | .ok_or_else(|| anyhow!("id not found")) 117 | } 118 | 119 | pub async fn route_post_user_login( 120 | State(ref mut state): State, 121 | mut cookies: CookieJar, 122 | Json(_body): Json, 123 | ) -> CookieJar { 124 | let mut lock = state.write().unwrap(); 125 | let user_todos = &mut lock.user_todos; 126 | let user_id = user_todos.len() as u32; 127 | user_todos.insert(user_id, vec![]); 128 | 129 | let really_insecure_login_cookie = Cookie::new(USER_ID_COOKIE_NAME, user_id.to_string()); 130 | cookies = cookies.add(really_insecure_login_cookie); 131 | 132 | cookies 133 | } 134 | 135 | pub async fn route_put_user_todos( 136 | State(ref mut state): State, 137 | mut cookies: CookieJar, 138 | Json(todo): Json, 139 | ) -> StdResult, StatusCode> { 140 | let user_id = get_user_id_from_cookie(&mut cookies).map_err(|_| StatusCode::UNAUTHORIZED)?; 141 | 142 | let mut lock = state.write().unwrap(); 143 | let todos = lock.user_todos.get_mut(&user_id).unwrap(); 144 | 145 | todos.push(todo); 146 | let num_todos = todos.len() as u32; 147 | 148 | Ok(Json(num_todos)) 149 | } 150 | 151 | pub async fn route_get_user_todos( 152 | State(ref state): State, 153 | mut cookies: CookieJar, 154 | ) -> StdResult>, StatusCode> { 155 | let user_id = get_user_id_from_cookie(&mut cookies).map_err(|_| StatusCode::UNAUTHORIZED)?; 156 | 157 | let lock = state.read().unwrap(); 158 | let todos = lock.user_todos[&user_id].clone(); 159 | 160 | Ok(Json(todos)) 161 | } 162 | 163 | #[cfg(test)] 164 | mod test_post_login { 165 | use super::*; 166 | 167 | use serde_json::json; 168 | 169 | #[tokio::test] 170 | async fn it_should_create_session_on_login() { 171 | let server = new_test_app(); 172 | 173 | let response = server 174 | .post(&"/login") 175 | .json(&json!({ 176 | "user": "my-login@example.com", 177 | })) 178 | .await; 
179 | 180 | let session_cookie = response.cookie(&USER_ID_COOKIE_NAME); 181 | assert_ne!(session_cookie.value(), ""); 182 | } 183 | 184 | #[tokio::test] 185 | async fn it_should_not_login_using_non_email() { 186 | let server = new_test_app(); 187 | 188 | let response = server 189 | .post(&"/login") 190 | .json(&json!({ 191 | "user": "blah blah blah", 192 | })) 193 | .expect_failure() 194 | .await; 195 | 196 | // There should not be a session created. 197 | let cookie = response.maybe_cookie(&USER_ID_COOKIE_NAME); 198 | assert!(cookie.is_none()); 199 | } 200 | } 201 | 202 | #[cfg(test)] 203 | mod test_route_put_user_todos { 204 | use super::*; 205 | 206 | use serde_json::json; 207 | 208 | #[tokio::test] 209 | async fn it_should_not_store_todos_without_login() { 210 | let server = new_test_app(); 211 | 212 | let response = server 213 | .put(&"/todo") 214 | .json(&json!({ 215 | "name": "shopping", 216 | "content": "buy eggs", 217 | })) 218 | .expect_failure() 219 | .await; 220 | 221 | assert_eq!(response.status_code(), StatusCode::UNAUTHORIZED); 222 | } 223 | 224 | #[tokio::test] 225 | async fn it_should_return_number_of_todos_as_more_are_pushed() { 226 | let server = new_test_app(); 227 | 228 | server 229 | .post(&"/login") 230 | .json(&json!({ 231 | "user": "my-login@example.com", 232 | })) 233 | .await; 234 | 235 | let num_todos = server 236 | .put(&"/todo") 237 | .json(&json!({ 238 | "name": "shopping", 239 | "content": "buy eggs", 240 | })) 241 | .await 242 | .json::(); 243 | assert_eq!(num_todos, 1); 244 | 245 | let num_todos = server 246 | .put(&"/todo") 247 | .json(&json!({ 248 | "name": "afternoon", 249 | "content": "buy shoes", 250 | })) 251 | .await 252 | .json::(); 253 | assert_eq!(num_todos, 2); 254 | } 255 | } 256 | 257 | #[cfg(test)] 258 | mod test_route_get_user_todos { 259 | use super::*; 260 | 261 | use serde_json::json; 262 | 263 | #[tokio::test] 264 | async fn it_should_not_return_todos_if_logged_out() { 265 | let server = new_test_app(); 266 | 267 
| let response = server 268 | .put(&"/todo") 269 | .json(&json!({ 270 | "name": "shopping", 271 | "content": "buy eggs", 272 | })) 273 | .expect_failure() 274 | .await; 275 | 276 | assert_eq!(response.status_code(), StatusCode::UNAUTHORIZED); 277 | } 278 | 279 | #[tokio::test] 280 | async fn it_should_return_all_todos_when_logged_in() { 281 | let server = new_test_app(); 282 | 283 | server 284 | .post(&"/login") 285 | .json(&json!({ 286 | "user": "my-login@example.com", 287 | })) 288 | .await; 289 | 290 | // Push two todos. 291 | server 292 | .put(&"/todo") 293 | .json(&json!({ 294 | "name": "shopping", 295 | "content": "buy eggs", 296 | })) 297 | .await; 298 | server 299 | .put(&"/todo") 300 | .json(&json!({ 301 | "name": "afternoon", 302 | "content": "buy shoes", 303 | })) 304 | .await; 305 | 306 | // Get all todos out from the server. 307 | let todos = server.get(&"/todo").await.json::>(); 308 | 309 | let expected_todos: Vec = vec![ 310 | Todo { 311 | name: "shopping".to_string(), 312 | content: "buy eggs".to_string(), 313 | }, 314 | Todo { 315 | name: "afternoon".to_string(), 316 | content: "buy shoes".to_string(), 317 | }, 318 | ]; 319 | assert_eq!(todos, expected_todos) 320 | } 321 | } 322 | -------------------------------------------------------------------------------- /examples/example-todo/main.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! This is an example Todo Application to show some simple tests. 3 | //! 4 | //! ```bash 5 | //! # To run it's tests: 6 | //! cargo test --example=example-todo 7 | //! ``` 8 | //! 9 | //! The app includes the end points for ... 10 | //! 11 | //! - POST /login ... this takes an email, and returns a session cookie. 12 | //! - PUT /todo ... once logged in, one can store todos. 13 | //! - GET /todo ... once logged in, you can retrieve all todos you have stored. 14 | //! 15 | //! At the bottom of this file are a series of tests for these endpoints. 16 | //! 
17 | 18 | use anyhow::Result; 19 | use anyhow::anyhow; 20 | use axum::Router; 21 | use axum::extract::Json; 22 | use axum::extract::State; 23 | use axum::routing::get; 24 | use axum::routing::post; 25 | use axum::routing::put; 26 | use axum::serve::serve; 27 | use axum_extra::extract::cookie::Cookie; 28 | use axum_extra::extract::cookie::CookieJar; 29 | use http::StatusCode; 30 | use serde::Deserialize; 31 | use serde::Serialize; 32 | use serde_email::Email; 33 | use std::collections::HashMap; 34 | use std::net::IpAddr; 35 | use std::net::Ipv4Addr; 36 | use std::net::SocketAddr; 37 | use std::result::Result as StdResult; 38 | use std::sync::Arc; 39 | use std::sync::RwLock; 40 | use tokio::net::TcpListener; 41 | 42 | #[cfg(test)] 43 | use axum_test::TestServer; 44 | 45 | const PORT: u16 = 8080; 46 | const USER_ID_COOKIE_NAME: &'static str = &"example-todo-user-id"; 47 | 48 | #[tokio::main] 49 | async fn main() { 50 | let result: Result<()> = { 51 | let app = new_app(); 52 | 53 | // Start! 54 | let ip_address = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); 55 | let address = SocketAddr::new(ip_address, PORT); 56 | let listener = TcpListener::bind(address).await.unwrap(); 57 | serve(listener, app.into_make_service()).await.unwrap(); 58 | 59 | Ok(()) 60 | }; 61 | 62 | match &result { 63 | Err(err) => eprintln!("{}", err), 64 | _ => {} 65 | }; 66 | } 67 | 68 | type SharedAppState = Arc>; 69 | 70 | // This is my poor man's in-memory DB.
71 | #[derive(Debug)] 72 | pub struct AppState { 73 | user_todos: HashMap>, 74 | } 75 | 76 | #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] 77 | pub struct Todo { 78 | name: String, 79 | content: String, 80 | } 81 | 82 | #[derive(Debug, Clone, Deserialize, Serialize)] 83 | pub struct LoginRequest { 84 | user: Email, 85 | } 86 | 87 | #[derive(Debug, Clone, Deserialize, Serialize)] 88 | pub struct AllTodos { 89 | todos: Vec, 90 | } 91 | 92 | #[derive(Debug, Clone, Deserialize, Serialize)] 93 | pub struct NumTodos { 94 | num: u32, 95 | } 96 | 97 | // Note you should never do something like this in a real application 98 | // for session cookies. It's really bad. Like _seriously_ bad. 99 | // 100 | // This is done like this here to keep the code shorter. That's all. 101 | fn get_user_id_from_cookie(cookies: &CookieJar) -> Result { 102 | cookies 103 | .get(&USER_ID_COOKIE_NAME) 104 | .map(|c| c.value().to_string().parse::().ok()) 105 | .flatten() 106 | .ok_or_else(|| anyhow!("id not found")) 107 | } 108 | 109 | pub async fn route_post_user_login( 110 | State(ref mut state): State, 111 | mut cookies: CookieJar, 112 | Json(_body): Json, 113 | ) -> CookieJar { 114 | let mut lock = state.write().unwrap(); 115 | let user_todos = &mut lock.user_todos; 116 | let user_id = user_todos.len() as u32; 117 | user_todos.insert(user_id, vec![]); 118 | 119 | let really_insecure_login_cookie = Cookie::new(USER_ID_COOKIE_NAME, user_id.to_string()); 120 | cookies = cookies.add(really_insecure_login_cookie); 121 | 122 | cookies 123 | } 124 | 125 | pub async fn route_put_user_todos( 126 | State(ref mut state): State, 127 | mut cookies: CookieJar, 128 | Json(todo): Json, 129 | ) -> StdResult, StatusCode> { 130 | let user_id = get_user_id_from_cookie(&mut cookies).map_err(|_| StatusCode::UNAUTHORIZED)?; 131 | 132 | let mut lock = state.write().unwrap(); 133 | let todos = lock.user_todos.get_mut(&user_id).unwrap(); 134 | 135 | todos.push(todo); 136 | let num_todos = todos.len() as 
u32; 137 | 138 | Ok(Json(num_todos)) 139 | } 140 | 141 | pub async fn route_get_user_todos( 142 | State(ref state): State, 143 | mut cookies: CookieJar, 144 | ) -> StdResult>, StatusCode> { 145 | let user_id = get_user_id_from_cookie(&mut cookies).map_err(|_| StatusCode::UNAUTHORIZED)?; 146 | 147 | let lock = state.read().unwrap(); 148 | let todos = lock.user_todos[&user_id].clone(); 149 | 150 | Ok(Json(todos)) 151 | } 152 | 153 | pub(crate) fn new_app() -> Router { 154 | let state = AppState { 155 | user_todos: HashMap::new(), 156 | }; 157 | let shared_state = Arc::new(RwLock::new(state)); 158 | 159 | Router::new() 160 | .route(&"/login", post(route_post_user_login)) 161 | .route(&"/todo", get(route_get_user_todos)) 162 | .route(&"/todo", put(route_put_user_todos)) 163 | .with_state(shared_state) 164 | } 165 | 166 | #[cfg(test)] 167 | fn new_test_app() -> TestServer { 168 | let app = new_app(); 169 | TestServer::builder() 170 | // Preserve cookies across requests 171 | // for the session cookie to work. 172 | .save_cookies() 173 | .expect_success_by_default() 174 | .mock_transport() 175 | .build(app) 176 | .unwrap() 177 | } 178 | 179 | #[cfg(test)] 180 | mod test_post_login { 181 | use super::*; 182 | 183 | use serde_json::json; 184 | 185 | #[tokio::test] 186 | async fn it_should_create_session_on_login() { 187 | let server = new_test_app(); 188 | 189 | let response = server 190 | .post(&"/login") 191 | .json(&json!({ 192 | "user": "my-login@example.com", 193 | })) 194 | .await; 195 | 196 | let session_cookie = response.cookie(&USER_ID_COOKIE_NAME); 197 | assert_ne!(session_cookie.value(), ""); 198 | } 199 | 200 | #[tokio::test] 201 | async fn it_should_not_login_using_non_email() { 202 | let server = new_test_app(); 203 | 204 | let response = server 205 | .post(&"/login") 206 | .json(&json!({ 207 | "user": "blah blah blah", 208 | })) 209 | .expect_failure() 210 | .await; 211 | 212 | // There should not be a session created. 
213 | let cookie = response.maybe_cookie(&USER_ID_COOKIE_NAME); 214 | assert!(cookie.is_none()); 215 | } 216 | } 217 | 218 | #[cfg(test)] 219 | mod test_route_put_user_todos { 220 | use super::*; 221 | 222 | use serde_json::json; 223 | 224 | #[tokio::test] 225 | async fn it_should_not_store_todos_without_login() { 226 | let server = new_test_app(); 227 | 228 | let response = server 229 | .put(&"/todo") 230 | .json(&json!({ 231 | "name": "shopping", 232 | "content": "buy eggs", 233 | })) 234 | .expect_failure() 235 | .await; 236 | 237 | assert_eq!(response.status_code(), StatusCode::UNAUTHORIZED); 238 | } 239 | 240 | #[tokio::test] 241 | async fn it_should_return_number_of_todos_as_more_are_pushed() { 242 | let server = new_test_app(); 243 | 244 | server 245 | .post(&"/login") 246 | .json(&json!({ 247 | "user": "my-login@example.com", 248 | })) 249 | .await; 250 | 251 | let num_todos = server 252 | .put(&"/todo") 253 | .json(&json!({ 254 | "name": "shopping", 255 | "content": "buy eggs", 256 | })) 257 | .await 258 | .json::(); 259 | assert_eq!(num_todos, 1); 260 | 261 | let num_todos = server 262 | .put(&"/todo") 263 | .json(&json!({ 264 | "name": "afternoon", 265 | "content": "buy shoes", 266 | })) 267 | .await 268 | .json::(); 269 | assert_eq!(num_todos, 2); 270 | } 271 | } 272 | 273 | #[cfg(test)] 274 | mod test_route_get_user_todos { 275 | use super::*; 276 | 277 | use serde_json::json; 278 | 279 | #[tokio::test] 280 | async fn it_should_not_return_todos_if_logged_out() { 281 | let server = new_test_app(); 282 | 283 | let response = server 284 | .put(&"/todo") 285 | .json(&json!({ 286 | "name": "shopping", 287 | "content": "buy eggs", 288 | })) 289 | .expect_failure() 290 | .await; 291 | 292 | assert_eq!(response.status_code(), StatusCode::UNAUTHORIZED); 293 | } 294 | 295 | #[tokio::test] 296 | async fn it_should_return_all_todos_when_logged_in() { 297 | let server = new_test_app(); 298 | 299 | server 300 | .post(&"/login") 301 | .json(&json!({ 302 | "user": 
"my-login@example.com", 303 | })) 304 | .await; 305 | 306 | // Push two todos. 307 | server 308 | .put(&"/todo") 309 | .json(&json!({ 310 | "name": "shopping", 311 | "content": "buy eggs", 312 | })) 313 | .await; 314 | server 315 | .put(&"/todo") 316 | .json(&json!({ 317 | "name": "afternoon", 318 | "content": "buy shoes", 319 | })) 320 | .await; 321 | 322 | // Get all todos out from the server. 323 | let todos = server.get(&"/todo").await.json::>(); 324 | 325 | let expected_todos: Vec = vec![ 326 | Todo { 327 | name: "shopping".to_string(), 328 | content: "buy eggs".to_string(), 329 | }, 330 | Todo { 331 | name: "afternoon".to_string(), 332 | content: "buy shoes".to_string(), 333 | }, 334 | ]; 335 | assert_eq!(todos, expected_todos) 336 | } 337 | } 338 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! Axum Test is a library for writing tests for web servers written using Axum: 3 | //! 4 | //! * You create a [`TestServer`] within a test, 5 | //! * use that to build [`TestRequest`] against your application, 6 | //! * receive back a [`TestResponse`], 7 | //! * then assert the response is how you expect. 8 | //! 9 | //! It includes built in support for serializing and deserializing request and response bodies using Serde, 10 | //! support for cookies and headers, and other common bits you would expect. 11 | //! 12 | //! `TestServer` will pass http requests directly to the handler, 13 | //! or can be run on a random IP / Port address. 14 | //! 15 | //! ## Getting Started 16 | //! 17 | //! Create a [`TestServer`] running your Axum [`Router`](::axum::Router): 18 | //! 19 | //! ```rust 20 | //! # async fn test() -> Result<(), Box> { 21 | //! # 22 | //! use axum::Router; 23 | //! use axum::extract::Json; 24 | //! use axum::routing::put; 25 | //! use axum_test::TestServer; 26 | //! use serde_json::json; 27 | //! 
use serde_json::Value; 28 | //! 29 | //! async fn route_put_user(Json(user): Json) -> () { 30 | //! // todo 31 | //! } 32 | //! 33 | //! let my_app = Router::new() 34 | //! .route("/users", put(route_put_user)); 35 | //! 36 | //! let server = TestServer::new(my_app)?; 37 | //! # 38 | //! # Ok(()) 39 | //! # } 40 | //! ``` 41 | //! 42 | //! Then make requests against it: 43 | //! 44 | //! ```rust 45 | //! # async fn test() -> Result<(), Box> { 46 | //! # 47 | //! # use axum::Router; 48 | //! # use axum::extract::Json; 49 | //! # use axum::routing::put; 50 | //! # use axum_test::TestServer; 51 | //! # use serde_json::json; 52 | //! # use serde_json::Value; 53 | //! # 54 | //! # async fn put_user(Json(user): Json) -> () {} 55 | //! # 56 | //! # let my_app = Router::new() 57 | //! # .route("/users", put(put_user)); 58 | //! # 59 | //! # let server = TestServer::new(my_app)?; 60 | //! # 61 | //! let response = server.put("/users") 62 | //! .json(&json!({ 63 | //! "username": "Terrance Pencilworth", 64 | //! })) 65 | //! .await; 66 | //! # 67 | //! # Ok(()) 68 | //! # } 69 | //! ``` 70 | //! 
71 | 72 | #![allow(clippy::module_inception)] 73 | #![allow(clippy::derivable_impls)] 74 | #![allow(clippy::manual_range_contains)] 75 | #![forbid(unsafe_code)] 76 | #![cfg_attr(docsrs, feature(doc_cfg))] 77 | 78 | pub(crate) mod internals; 79 | 80 | pub mod multipart; 81 | 82 | pub mod transport_layer; 83 | pub mod util; 84 | 85 | mod test_request; 86 | pub use self::test_request::*; 87 | 88 | mod test_response; 89 | pub use self::test_response::*; 90 | 91 | mod test_server_builder; 92 | pub use self::test_server_builder::*; 93 | 94 | mod test_server_config; 95 | pub use self::test_server_config::*; 96 | 97 | mod test_server; 98 | pub use self::test_server::*; 99 | 100 | #[cfg(feature = "ws")] 101 | mod test_web_socket; 102 | #[cfg(feature = "ws")] 103 | pub use self::test_web_socket::*; 104 | #[cfg(feature = "ws")] 105 | pub use tokio_tungstenite::tungstenite::Message as WsMessage; 106 | 107 | mod transport; 108 | pub use self::transport::*; 109 | 110 | pub mod expect_json; 111 | 112 | pub use http; 113 | 114 | #[cfg(test)] 115 | mod testing; 116 | 117 | #[cfg(test)] 118 | mod integrated_test_cookie_saving { 119 | use super::*; 120 | 121 | use axum::Router; 122 | use axum::extract::Request; 123 | use axum::routing::get; 124 | use axum::routing::post; 125 | use axum::routing::put; 126 | use axum_extra::extract::cookie::Cookie as AxumCookie; 127 | use axum_extra::extract::cookie::CookieJar; 128 | use cookie::Cookie; 129 | use cookie::time::OffsetDateTime; 130 | use http_body_util::BodyExt; 131 | use std::time::Duration; 132 | 133 | const TEST_COOKIE_NAME: &'static str = &"test-cookie"; 134 | 135 | async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) { 136 | let cookie = cookies.get(&TEST_COOKIE_NAME); 137 | let cookie_value = cookie 138 | .map(|c| c.value().to_string()) 139 | .unwrap_or_else(|| "cookie-not-found".to_string()); 140 | 141 | (cookies, cookie_value) 142 | } 143 | 144 | async fn put_cookie(mut cookies: CookieJar, request: Request) -> 
(CookieJar, &'static str) { 145 | let body_bytes = request 146 | .into_body() 147 | .collect() 148 | .await 149 | .expect("Should extract the body") 150 | .to_bytes(); 151 | let body_text: String = String::from_utf8_lossy(&body_bytes).to_string(); 152 | let cookie = AxumCookie::new(TEST_COOKIE_NAME, body_text); 153 | cookies = cookies.add(cookie); 154 | 155 | (cookies, &"done") 156 | } 157 | 158 | async fn post_expire_cookie(mut cookies: CookieJar) -> (CookieJar, &'static str) { 159 | let mut cookie = AxumCookie::new(TEST_COOKIE_NAME, "expired".to_string()); 160 | let expired_time = OffsetDateTime::now_utc() - Duration::from_secs(1); 161 | cookie.set_expires(expired_time); 162 | cookies = cookies.add(cookie); 163 | 164 | (cookies, &"done") 165 | } 166 | 167 | fn new_test_router() -> Router { 168 | Router::new() 169 | .route("/cookie", put(put_cookie)) 170 | .route("/cookie", get(get_cookie)) 171 | .route("/expire", post(post_expire_cookie)) 172 | } 173 | 174 | #[tokio::test] 175 | async fn it_should_not_pass_cookies_created_back_up_to_server_by_default() { 176 | // Run the server. 177 | let server = TestServer::new(new_test_router()).expect("Should create test server"); 178 | 179 | // Create a cookie. 180 | server.put(&"/cookie").text(&"new-cookie").await; 181 | 182 | // Check it comes back. 183 | let response_text = server.get(&"/cookie").await.text(); 184 | 185 | assert_eq!(response_text, "cookie-not-found"); 186 | } 187 | 188 | #[tokio::test] 189 | async fn it_should_not_pass_cookies_created_back_up_to_server_when_turned_off() { 190 | // Run the server. 191 | let server = TestServer::builder() 192 | .do_not_save_cookies() 193 | .build(new_test_router()) 194 | .expect("Should create test server"); 195 | 196 | // Create a cookie. 197 | server.put(&"/cookie").text(&"new-cookie").await; 198 | 199 | // Check it comes back. 
200 | let response_text = server.get(&"/cookie").await.text(); 201 | 202 | assert_eq!(response_text, "cookie-not-found"); 203 | } 204 | 205 | #[tokio::test] 206 | async fn it_should_pass_cookies_created_back_up_to_server_automatically() { 207 | // Run the server. 208 | let server = TestServer::builder() 209 | .save_cookies() 210 | .build(new_test_router()) 211 | .expect("Should create test server"); 212 | 213 | // Create a cookie. 214 | server.put(&"/cookie").text(&"cookie-found!").await; 215 | 216 | // Check it comes back. 217 | let response_text = server.get(&"/cookie").await.text(); 218 | 219 | assert_eq!(response_text, "cookie-found!"); 220 | } 221 | 222 | #[tokio::test] 223 | async fn it_should_pass_cookies_created_back_up_to_server_when_turned_on_for_request() { 224 | // Run the server. 225 | let server = TestServer::builder() 226 | .do_not_save_cookies() // it's off by default! 227 | .build(new_test_router()) 228 | .expect("Should create test server"); 229 | 230 | // Create a cookie. 231 | server 232 | .put(&"/cookie") 233 | .text(&"cookie-found!") 234 | .save_cookies() 235 | .await; 236 | 237 | // Check it comes back. 238 | let response_text = server.get(&"/cookie").await.text(); 239 | 240 | assert_eq!(response_text, "cookie-found!"); 241 | } 242 | 243 | #[tokio::test] 244 | async fn it_should_wipe_cookies_cleared_by_request() { 245 | // Run the server. 246 | let server = TestServer::builder() 247 | .do_not_save_cookies() // it's off by default! 248 | .build(new_test_router()) 249 | .expect("Should create test server"); 250 | 251 | // Create a cookie. 252 | server 253 | .put(&"/cookie") 254 | .text(&"cookie-found!") 255 | .save_cookies() 256 | .await; 257 | 258 | // Check it comes back. 259 | let response_text = server.get(&"/cookie").clear_cookies().await.text(); 260 | 261 | assert_eq!(response_text, "cookie-not-found"); 262 | } 263 | 264 | #[tokio::test] 265 | async fn it_should_wipe_cookies_cleared_by_test_server() { 266 | // Run the server. 
267 | let mut server = TestServer::builder() 268 | .do_not_save_cookies() // it's off by default! 269 | .build(new_test_router()) 270 | .expect("Should create test server"); 271 | 272 | // Create a cookie. 273 | server 274 | .put(&"/cookie") 275 | .text(&"cookie-found!") 276 | .save_cookies() 277 | .await; 278 | 279 | server.clear_cookies(); 280 | 281 | // Check it comes back. 282 | let response_text = server.get(&"/cookie").await.text(); 283 | 284 | assert_eq!(response_text, "cookie-not-found"); 285 | } 286 | 287 | #[tokio::test] 288 | async fn it_should_send_cookies_added_to_request() { 289 | // Run the server. 290 | let server = TestServer::builder() 291 | .do_not_save_cookies() // it's off by default! 292 | .build(new_test_router()) 293 | .expect("Should create test server"); 294 | 295 | // Check it comes back. 296 | let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"); 297 | 298 | let response_text = server.get(&"/cookie").add_cookie(cookie).await.text(); 299 | 300 | assert_eq!(response_text, "my-custom-cookie"); 301 | } 302 | 303 | #[tokio::test] 304 | async fn it_should_send_cookies_added_to_test_server() { 305 | // Run the server. 306 | let mut server = TestServer::builder() 307 | .do_not_save_cookies() // it's off by default! 308 | .build(new_test_router()) 309 | .expect("Should create test server"); 310 | 311 | // Check it comes back. 312 | let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"); 313 | server.add_cookie(cookie); 314 | 315 | let response_text = server.get(&"/cookie").await.text(); 316 | 317 | assert_eq!(response_text, "my-custom-cookie"); 318 | } 319 | 320 | #[tokio::test] 321 | async fn it_should_remove_expired_cookies_from_later_requests() { 322 | // Run the server. 323 | let mut server = TestServer::new(new_test_router()).expect("Should create test server"); 324 | server.save_cookies(); 325 | 326 | // Create a cookie. 327 | server.put(&"/cookie").text(&"cookie-found!").await; 328 | 329 | // Check it comes back. 
330 | let response_text = server.get(&"/cookie").await.text(); 331 | assert_eq!(response_text, "cookie-found!"); 332 | 333 | server.post(&"/expire").await; 334 | 335 | // Then expire the cookie. 336 | let found_cookie = server.post(&"/expire").await.maybe_cookie(TEST_COOKIE_NAME); 337 | assert!(found_cookie.is_some()); 338 | 339 | // It's no longer found 340 | let response_text = server.get(&"/cookie").await.text(); 341 | assert_eq!(response_text, "cookie-not-found"); 342 | } 343 | } 344 | 345 | #[cfg(feature = "typed-routing")] 346 | #[cfg(test)] 347 | mod integrated_test_typed_routing_and_query { 348 | use super::*; 349 | 350 | use axum::Router; 351 | use axum::extract::Query; 352 | use axum_extra::routing::RouterExt; 353 | use axum_extra::routing::TypedPath; 354 | use serde::Deserialize; 355 | use serde::Serialize; 356 | 357 | #[derive(TypedPath, Deserialize)] 358 | #[typed_path("/path-query/{id}")] 359 | struct TestingPathQuery { 360 | id: u32, 361 | } 362 | 363 | #[derive(Serialize, Deserialize)] 364 | struct QueryParams { 365 | param: String, 366 | other: Option, 367 | } 368 | 369 | async fn route_get_with_param( 370 | TestingPathQuery { id }: TestingPathQuery, 371 | Query(params): Query, 372 | ) -> String { 373 | let query = params.param; 374 | if let Some(other) = params.other { 375 | format!("get {id}, {query}&{other}") 376 | } else { 377 | format!("get {id}, {query}") 378 | } 379 | } 380 | 381 | fn new_app() -> Router { 382 | Router::new().typed_get(route_get_with_param) 383 | } 384 | 385 | #[tokio::test] 386 | async fn it_should_send_typed_get_with_query_params() { 387 | let server = TestServer::new(new_app()).unwrap(); 388 | let path = TestingPathQuery { id: 123 }.with_query_params(QueryParams { 389 | param: "with-typed-query".to_string(), 390 | other: None, 391 | }); 392 | 393 | server 394 | .typed_get(&path) 395 | .expect_success() 396 | .await 397 | .assert_text("get 123, with-typed-query"); 398 | } 399 | 400 | #[tokio::test] 401 | async fn 
it_should_send_typed_get_with_added_query_param() { 402 | let server = TestServer::new(new_app()).unwrap(); 403 | let path = TestingPathQuery { id: 123 }; 404 | 405 | server 406 | .typed_get(&path) 407 | .add_query_param("param", "with-added-query") 408 | .expect_success() 409 | .await 410 | .assert_text("get 123, with-added-query"); 411 | } 412 | 413 | #[tokio::test] 414 | async fn it_should_send_both_typed_and_added_query() { 415 | let server = TestServer::new(new_app()).unwrap(); 416 | let path = TestingPathQuery { id: 123 }.with_query_params(QueryParams { 417 | param: "with-typed-query".to_string(), 418 | other: None, 419 | }); 420 | 421 | server 422 | .typed_get(&path) 423 | .add_query_param("other", "with-added-query") 424 | .expect_success() 425 | .await 426 | .assert_text("get 123, with-typed-query&with-added-query"); 427 | } 428 | 429 | #[tokio::test] 430 | async fn it_should_send_replaced_query_when_cleared() { 431 | let server = TestServer::new(new_app()).unwrap(); 432 | let path = TestingPathQuery { id: 123 }.with_query_params(QueryParams { 433 | param: "with-typed-query".to_string(), 434 | other: Some("with-typed-other".to_string()), 435 | }); 436 | 437 | server 438 | .typed_get(&path) 439 | .clear_query_params() 440 | .add_query_param("param", "with-added-query") 441 | .expect_success() 442 | .await 443 | .assert_text("get 123, with-added-query"); 444 | } 445 | } 446 | -------------------------------------------------------------------------------- /src/test_web_socket.rs: -------------------------------------------------------------------------------- 1 | use crate::WsMessage; 2 | use anyhow::Context; 3 | use anyhow::Result; 4 | use anyhow::anyhow; 5 | use bytes::Bytes; 6 | use futures_util::sink::SinkExt; 7 | use futures_util::stream::StreamExt; 8 | use hyper::upgrade::Upgraded; 9 | use hyper_util::rt::TokioIo; 10 | use serde::Serialize; 11 | use serde::de::DeserializeOwned; 12 | use serde_json::Value; 13 | use std::fmt::Debug; 14 | use 
std::fmt::Display; 15 | use tokio_tungstenite::WebSocketStream; 16 | use tokio_tungstenite::tungstenite::protocol::Role; 17 | 18 | #[cfg(feature = "pretty-assertions")] 19 | use pretty_assertions::assert_eq; 20 | 21 | #[cfg(not(feature = "old-json-diff"))] 22 | use expect_json::expect; 23 | #[cfg(not(feature = "old-json-diff"))] 24 | use expect_json::expect_json_eq; 25 | 26 | #[derive(Debug)] 27 | pub struct TestWebSocket { 28 | stream: WebSocketStream>, 29 | } 30 | 31 | impl TestWebSocket { 32 | pub(crate) async fn new(upgraded: Upgraded) -> Self { 33 | let upgraded_io = TokioIo::new(upgraded); 34 | let stream = WebSocketStream::from_raw_socket(upgraded_io, Role::Client, None).await; 35 | 36 | Self { stream } 37 | } 38 | 39 | pub async fn close(mut self) { 40 | self.stream 41 | .close(None) 42 | .await 43 | .expect("Failed to close WebSocket stream"); 44 | } 45 | 46 | pub async fn send_text(&mut self, raw_text: T) 47 | where 48 | T: Display, 49 | { 50 | let text = raw_text.to_string(); 51 | self.send_message(WsMessage::Text(text.into())).await; 52 | } 53 | 54 | pub async fn send_json(&mut self, body: &J) 55 | where 56 | J: ?Sized + Serialize, 57 | { 58 | let raw_json = 59 | ::serde_json::to_string(body).expect("It should serialize the content into Json"); 60 | 61 | self.send_message(WsMessage::Text(raw_json.into())).await; 62 | } 63 | 64 | #[cfg(feature = "yaml")] 65 | pub async fn send_yaml(&mut self, body: &Y) 66 | where 67 | Y: ?Sized + Serialize, 68 | { 69 | let raw_yaml = 70 | ::serde_yaml::to_string(body).expect("It should serialize the content into Yaml"); 71 | 72 | self.send_message(WsMessage::Text(raw_yaml.into())).await; 73 | } 74 | 75 | #[cfg(feature = "msgpack")] 76 | pub async fn send_msgpack(&mut self, body: &M) 77 | where 78 | M: ?Sized + Serialize, 79 | { 80 | let body_bytes = 81 | ::rmp_serde::to_vec(body).expect("It should serialize the content into MsgPack"); 82 | 83 | self.send_message(WsMessage::Binary(body_bytes.into())) 84 | .await; 85 | } 
86 | 87 | pub async fn send_message(&mut self, message: WsMessage) { 88 | self.stream.send(message).await.unwrap(); 89 | } 90 | 91 | #[must_use] 92 | pub async fn receive_text(&mut self) -> String { 93 | let message = self.receive_message().await; 94 | 95 | message_to_text(message) 96 | .context("Failed to read message as a String") 97 | .unwrap() 98 | } 99 | 100 | #[must_use] 101 | pub async fn receive_json(&mut self) -> T 102 | where 103 | T: DeserializeOwned, 104 | { 105 | let bytes = self.receive_bytes().await; 106 | serde_json::from_slice::(&bytes) 107 | .context("Failed to deserialize message as Json") 108 | .unwrap() 109 | } 110 | 111 | #[cfg(feature = "yaml")] 112 | #[must_use] 113 | pub async fn receive_yaml(&mut self) -> T 114 | where 115 | T: DeserializeOwned, 116 | { 117 | let bytes = self.receive_bytes().await; 118 | serde_yaml::from_slice::(&bytes) 119 | .context("Failed to deserialize message as Yaml") 120 | .unwrap() 121 | } 122 | 123 | #[cfg(feature = "msgpack")] 124 | #[must_use] 125 | pub async fn receive_msgpack(&mut self) -> T 126 | where 127 | T: DeserializeOwned, 128 | { 129 | let received_bytes = self.receive_bytes().await; 130 | rmp_serde::from_slice::(&received_bytes) 131 | .context("Failed to deserializing message as MsgPack") 132 | .unwrap() 133 | } 134 | 135 | #[must_use] 136 | pub async fn receive_bytes(&mut self) -> Bytes { 137 | let message = self.receive_message().await; 138 | 139 | message_to_bytes(message) 140 | .context("Failed to read message as a Bytes") 141 | .unwrap() 142 | } 143 | 144 | #[must_use] 145 | pub async fn receive_message(&mut self) -> WsMessage { 146 | self.maybe_receive_message() 147 | .await 148 | .expect("No message found on WebSocket stream") 149 | } 150 | 151 | #[must_use] 152 | async fn maybe_receive_message(&mut self) -> Option { 153 | let maybe_message = self.stream.next().await; 154 | 155 | match maybe_message { 156 | None => None, 157 | Some(message_result) => { 158 | let message = 159 | 
message_result.expect("Failed to receive message from WebSocket stream"); 160 | Some(message) 161 | } 162 | } 163 | } 164 | 165 | pub async fn assert_receive_json(&mut self, expected: &T) 166 | where 167 | T: Serialize + DeserializeOwned + PartialEq + Debug, 168 | { 169 | let received = self.receive_json::().await; 170 | 171 | #[cfg(feature = "old-json-diff")] 172 | { 173 | assert_eq!(*expected, received); 174 | } 175 | 176 | #[cfg(not(feature = "old-json-diff"))] 177 | { 178 | if *expected != received { 179 | if let Err(error) = expect_json_eq(&received, &expected) { 180 | panic!( 181 | " 182 | {error} 183 | ", 184 | ); 185 | } 186 | } 187 | } 188 | } 189 | 190 | pub async fn assert_receive_json_contains(&mut self, expected: &T) 191 | where 192 | T: Serialize, 193 | { 194 | let received = self.receive_json::().await; 195 | 196 | #[cfg(feature = "old-json-diff")] 197 | { 198 | assert_json_diff::assert_json_include!(actual: received, expected: expected); 199 | } 200 | 201 | #[cfg(not(feature = "old-json-diff"))] 202 | { 203 | let expected_value = serde_json::to_value(expected).unwrap(); 204 | let result = expect_json_eq( 205 | &received, 206 | &expect::object().propagated_contains(expected_value), 207 | ); 208 | if let Err(error) = result { 209 | panic!( 210 | " 211 | {error} 212 | ", 213 | ); 214 | } 215 | } 216 | } 217 | 218 | pub async fn assert_receive_text(&mut self, expected: C) 219 | where 220 | C: AsRef, 221 | { 222 | let expected_contents = expected.as_ref(); 223 | assert_eq!(expected_contents, &self.receive_text().await); 224 | } 225 | 226 | pub async fn assert_receive_text_contains(&mut self, expected: C) 227 | where 228 | C: AsRef, 229 | { 230 | let expected_contents = expected.as_ref(); 231 | let received = self.receive_text().await; 232 | let is_contained = received.contains(expected_contents); 233 | 234 | assert!( 235 | is_contained, 236 | "Failed to find '{expected_contents}', received '{received}'" 237 | ); 238 | } 239 | 240 | #[cfg(feature = 
"yaml")] 241 | pub async fn assert_receive_yaml(&mut self, expected: &T) 242 | where 243 | T: DeserializeOwned + PartialEq + Debug, 244 | { 245 | assert_eq!(*expected, self.receive_yaml::().await); 246 | } 247 | 248 | #[cfg(feature = "msgpack")] 249 | pub async fn assert_receive_msgpack(&mut self, expected: &T) 250 | where 251 | T: DeserializeOwned + PartialEq + Debug, 252 | { 253 | assert_eq!(*expected, self.receive_msgpack::().await); 254 | } 255 | } 256 | 257 | fn message_to_text(message: WsMessage) -> Result { 258 | let text = match message { 259 | WsMessage::Text(text) => text.to_string(), 260 | WsMessage::Binary(data) => { 261 | String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())? 262 | } 263 | WsMessage::Ping(data) => { 264 | String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())? 265 | } 266 | WsMessage::Pong(data) => { 267 | String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())? 268 | } 269 | WsMessage::Close(None) => String::new(), 270 | WsMessage::Close(Some(frame)) => frame.reason.to_string(), 271 | WsMessage::Frame(_) => { 272 | return Err(anyhow!( 273 | "Unexpected Frame, did not expect Frame message whilst reading" 274 | )); 275 | } 276 | }; 277 | 278 | Ok(text) 279 | } 280 | 281 | fn message_to_bytes(message: WsMessage) -> Result { 282 | let bytes = match message { 283 | WsMessage::Text(string) => string.into(), 284 | WsMessage::Binary(data) => data, 285 | WsMessage::Ping(data) => data, 286 | WsMessage::Pong(data) => data, 287 | WsMessage::Close(None) => Bytes::new(), 288 | WsMessage::Close(Some(frame)) => frame.reason.into(), 289 | WsMessage::Frame(_) => { 290 | return Err(anyhow!( 291 | "Unexpected Frame, did not expect Frame message whilst reading" 292 | )); 293 | } 294 | }; 295 | 296 | Ok(bytes) 297 | } 298 | 299 | #[cfg(test)] 300 | mod test_assert_receive_text { 301 | use crate::TestServer; 302 | 303 | use axum::Router; 304 | use axum::extract::WebSocketUpgrade; 305 | use axum::extract::ws::Message; 306 | use 
axum::extract::ws::WebSocket; 307 | use axum::response::Response; 308 | use axum::routing::get; 309 | 310 | fn new_test_app() -> TestServer { 311 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 312 | async fn handle_ping_pong(mut socket: WebSocket) { 313 | while let Some(maybe_message) = socket.recv().await { 314 | let message_text = maybe_message.unwrap().into_text().unwrap(); 315 | 316 | let encoded_text = format!("Text: {message_text}").try_into().unwrap(); 317 | let encoded_data = format!("Binary: {message_text}").into_bytes().into(); 318 | 319 | socket.send(Message::Text(encoded_text)).await.unwrap(); 320 | socket.send(Message::Binary(encoded_data)).await.unwrap(); 321 | } 322 | } 323 | 324 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 325 | } 326 | 327 | let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)); 328 | TestServer::builder().http_transport().build(app).unwrap() 329 | } 330 | 331 | #[tokio::test] 332 | async fn it_should_ping_pong_text_in_text_and_binary() { 333 | let server = new_test_app(); 334 | 335 | let mut websocket = server 336 | .get_websocket(&"/ws-ping-pong") 337 | .await 338 | .into_websocket() 339 | .await; 340 | 341 | websocket.send_text("Hello World!").await; 342 | 343 | websocket.assert_receive_text("Text: Hello World!").await; 344 | websocket.assert_receive_text("Binary: Hello World!").await; 345 | } 346 | 347 | #[tokio::test] 348 | async fn it_should_ping_pong_large_text_blobs() { 349 | const LARGE_BLOB_SIZE: usize = 16777200; // Max websocket size (16mb) - 16 bytes for the 'Text: ' in the reply. 
350 | let large_blob = (0..LARGE_BLOB_SIZE).map(|_| "X").collect::(); 351 | 352 | let server = new_test_app(); 353 | let mut websocket = server 354 | .get_websocket(&"/ws-ping-pong") 355 | .await 356 | .into_websocket() 357 | .await; 358 | 359 | websocket.send_text(&large_blob).await; 360 | 361 | websocket 362 | .assert_receive_text(format!("Text: {large_blob}")) 363 | .await; 364 | websocket 365 | .assert_receive_text(format!("Binary: {large_blob}")) 366 | .await; 367 | } 368 | 369 | #[tokio::test] 370 | #[should_panic] 371 | async fn it_should_not_match_partial_text_match() { 372 | let server = new_test_app(); 373 | 374 | let mut websocket = server 375 | .get_websocket(&"/ws-ping-pong") 376 | .await 377 | .into_websocket() 378 | .await; 379 | 380 | websocket.send_text("Hello World!").await; 381 | websocket.assert_receive_text("Hello World!").await; 382 | } 383 | 384 | #[tokio::test] 385 | #[should_panic] 386 | async fn it_should_not_match_different_text() { 387 | let server = new_test_app(); 388 | 389 | let mut websocket = server 390 | .get_websocket(&"/ws-ping-pong") 391 | .await 392 | .into_websocket() 393 | .await; 394 | 395 | websocket.send_text("Hello World!").await; 396 | websocket.assert_receive_text("🦊").await; 397 | } 398 | } 399 | 400 | #[cfg(test)] 401 | mod test_assert_receive_text_contains { 402 | use crate::TestServer; 403 | use axum::Router; 404 | use axum::extract::WebSocketUpgrade; 405 | use axum::extract::ws::Message; 406 | use axum::extract::ws::WebSocket; 407 | use axum::response::Response; 408 | use axum::routing::get; 409 | 410 | fn new_test_app() -> TestServer { 411 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 412 | async fn handle_ping_pong(mut socket: WebSocket) { 413 | while let Some(maybe_message) = socket.recv().await { 414 | let message_text = maybe_message.unwrap().into_text().unwrap(); 415 | let encoded_text = format!("Text: {message_text}").try_into().unwrap(); 416 | 417 | 
socket.send(Message::Text(encoded_text)).await.unwrap(); 418 | } 419 | } 420 | 421 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 422 | } 423 | 424 | let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)); 425 | TestServer::builder().http_transport().build(app).unwrap() 426 | } 427 | 428 | #[tokio::test] 429 | async fn it_should_assert_whole_text_match() { 430 | let server = new_test_app(); 431 | 432 | let mut websocket = server 433 | .get_websocket(&"/ws-ping-pong") 434 | .await 435 | .into_websocket() 436 | .await; 437 | 438 | websocket.send_text("Hello World!").await; 439 | websocket 440 | .assert_receive_text_contains("Text: Hello World!") 441 | .await; 442 | } 443 | 444 | #[tokio::test] 445 | async fn it_should_assert_partial_text_match() { 446 | let server = new_test_app(); 447 | 448 | let mut websocket = server 449 | .get_websocket(&"/ws-ping-pong") 450 | .await 451 | .into_websocket() 452 | .await; 453 | 454 | websocket.send_text("Hello World!").await; 455 | websocket.assert_receive_text_contains("Hello World!").await; 456 | } 457 | 458 | #[tokio::test] 459 | #[should_panic] 460 | async fn it_should_not_match_different_text() { 461 | let server = new_test_app(); 462 | 463 | let mut websocket = server 464 | .get_websocket(&"/ws-ping-pong") 465 | .await 466 | .into_websocket() 467 | .await; 468 | 469 | websocket.send_text("Hello World!").await; 470 | websocket.assert_receive_text_contains("🦊").await; 471 | } 472 | } 473 | 474 | #[cfg(test)] 475 | mod test_assert_receive_json { 476 | use crate::TestServer; 477 | use axum::Router; 478 | use axum::extract::WebSocketUpgrade; 479 | use axum::extract::ws::Message; 480 | use axum::extract::ws::WebSocket; 481 | use axum::response::Response; 482 | use axum::routing::get; 483 | use serde_json::Value; 484 | use serde_json::json; 485 | 486 | #[cfg(not(feature = "old-json-diff"))] 487 | use crate::testing::ExpectStrMinLen; 488 | #[cfg(not(feature = "old-json-diff"))] 489 | use 
expect_json::expect; 490 | 491 | fn new_test_app() -> TestServer { 492 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 493 | async fn handle_ping_pong(mut socket: WebSocket) { 494 | while let Some(maybe_message) = socket.recv().await { 495 | let message_text = maybe_message.unwrap().into_text().unwrap(); 496 | let decoded = serde_json::from_str::(&message_text).unwrap(); 497 | 498 | let encoded_text = serde_json::to_string(&json!({ 499 | "format": "text", 500 | "message": decoded 501 | })) 502 | .unwrap() 503 | .try_into() 504 | .unwrap(); 505 | let encoded_data = serde_json::to_vec(&json!({ 506 | "format": "binary", 507 | "message": decoded 508 | })) 509 | .unwrap() 510 | .into(); 511 | 512 | socket.send(Message::Text(encoded_text)).await.unwrap(); 513 | socket.send(Message::Binary(encoded_data)).await.unwrap(); 514 | } 515 | } 516 | 517 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 518 | } 519 | 520 | let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)); 521 | TestServer::builder().http_transport().build(app).unwrap() 522 | } 523 | 524 | #[tokio::test] 525 | async fn it_should_ping_pong_json_in_text_and_binary() { 526 | let server = new_test_app(); 527 | 528 | let mut websocket = server 529 | .get_websocket(&"/ws-ping-pong") 530 | .await 531 | .into_websocket() 532 | .await; 533 | 534 | websocket 535 | .send_json(&json!({ 536 | "hello": "world", 537 | "numbers": [1, 2, 3], 538 | })) 539 | .await; 540 | 541 | // Once for text 542 | websocket 543 | .assert_receive_json(&json!({ 544 | "format": "text", 545 | "message": { 546 | "hello": "world", 547 | "numbers": [1, 2, 3], 548 | }, 549 | })) 550 | .await; 551 | 552 | // Again for binary 553 | websocket 554 | .assert_receive_json(&json!({ 555 | "format": "binary", 556 | "message": { 557 | "hello": "world", 558 | "numbers": [1, 2, 3], 559 | }, 560 | })) 561 | .await; 562 | } 563 | 564 | #[cfg(not(feature = "old-json-diff"))] 565 | #[tokio::test] 
566 | /* Custom expect ops (`ExpectStrMinLen` and the `expect::` matchers) should be accepted inside `assert_receive_json`, for both the text echo and the binary echo. */ async fn it_should_work_with_custom_expect_op() { 567 | let server = new_test_app(); 568 | let mut websocket = server 569 | .get_websocket(&"/ws-ping-pong") 570 | .await 571 | .into_websocket() 572 | .await; 573 | 574 | websocket 575 | .send_json(&json!({ 576 | "hello": "world", 577 | "numbers": [1, 2, 3], 578 | })) 579 | .await; 580 | 581 | // Once for text 582 | websocket 583 | .assert_receive_json(&json!({ 584 | "format": "text", 585 | "message": { 586 | "hello": ExpectStrMinLen { min: 3 }, 587 | "numbers": expect::array().len(3).all(expect::integer()), 588 | }, 589 | })) 590 | .await; 591 | 592 | // Again for binary 593 | websocket 594 | .assert_receive_json(&json!({ 595 | "format": "binary", 596 | "message": { 597 | "hello": ExpectStrMinLen { min: 3 }, 598 | "numbers": expect::array().len(3).all(expect::integer()), 599 | }, 600 | })) 601 | .await; 602 | } 603 | 604 | #[cfg(not(feature = "old-json-diff"))] 605 | #[tokio::test] 606 | #[should_panic] 607 | /* `ExpectStrMinLen { min: 10 }` is expected to reject "world" (5 chars), so the text-frame assertion should panic. */ async fn it_should_panic_if_custom_expect_op_fails() { 608 | let server = new_test_app(); 609 | let mut websocket = server 610 | .get_websocket(&"/ws-ping-pong") 611 | .await 612 | .into_websocket() 613 | .await; 614 | 615 | websocket 616 | .send_json(&json!({ 617 | "hello": "world", 618 | "numbers": [1, 2, 3], 619 | })) 620 | .await; 621 | 622 | // Once for text 623 | websocket 624 | .assert_receive_json(&json!({ 625 | "format": "text", 626 | "message": { 627 | "hello": ExpectStrMinLen { min: 10 }, 628 | "numbers": expect::array().len(3).all(expect::integer()), 629 | }, 630 | })) 631 | .await; 632 | } 633 | } 634 | 635 | #[cfg(test)] 636 | mod test_assert_receive_json_contains { 637 | use crate::TestServer; 638 | use axum::Router; 639 | use axum::extract::WebSocketUpgrade; 640 | use axum::extract::ws::Message; 641 | use axum::extract::ws::WebSocket; 642 | use axum::response::Response; 643 | use axum::routing::get; 644 | use serde_json::Value; 645 | use serde_json::json; 646 | 647 | /* Builds an HTTP-transport `TestServer` whose `/ws-ping-pong` route decodes each incoming message as JSON and echoes it back twice, wrapped as {"format", "message"}: once as a Text frame and once as a Binary frame. Any receive/decode/send error panics via `unwrap`. */ fn new_test_app() ->
TestServer { 648 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 649 | async fn handle_ping_pong(mut socket: WebSocket) { 650 | while let Some(maybe_message) = socket.recv().await { 651 | let message_text = maybe_message.unwrap().into_text().unwrap(); 652 | /* `decoded` is only consumed by `json!`, which cannot drive type inference, so the target type must be stated explicitly. */ let decoded: Value = serde_json::from_str(&message_text).unwrap(); 653 | 654 | let encoded_text = serde_json::to_string(&json!({ 655 | "format": "text", 656 | "message": decoded 657 | })) 658 | .unwrap() 659 | .try_into() 660 | .unwrap(); 661 | let encoded_data = serde_json::to_vec(&json!({ 662 | "format": "binary", 663 | "message": decoded 664 | })) 665 | .unwrap() 666 | .into(); 667 | 668 | socket.send(Message::Text(encoded_text)).await.unwrap(); 669 | socket.send(Message::Binary(encoded_data)).await.unwrap(); 670 | } 671 | } 672 | 673 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 674 | } 675 | 676 | let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)); 677 | TestServer::builder().http_transport().build(app).unwrap() 678 | } 679 | 680 | #[tokio::test] 681 | /* `assert_receive_json_contains` should pass when top-level keys are omitted from the expected value. */ async fn it_should_ping_pong_json_in_text_and_binary_with_root_content_missing_in_contains() { 682 | let server = new_test_app(); 683 | 684 | let mut websocket = server 685 | .get_websocket(&"/ws-ping-pong") 686 | .await 687 | .into_websocket() 688 | .await; 689 | 690 | websocket 691 | .send_json(&json!({ 692 | "hello": "world", 693 | "numbers": [1, 2, 3], 694 | })) 695 | .await; 696 | 697 | // Once for text 698 | websocket 699 | .assert_receive_json_contains(&json!({ 700 | // "format" is missing here 701 | "message": { 702 | "hello": "world", 703 | "numbers": [1, 2, 3], 704 | }, 705 | })) 706 | .await; 707 | 708 | // Again for binary 709 | websocket 710 | .assert_receive_json_contains(&json!({ 711 | "format": "binary", 712 | // "message" is missing here 713 | })) 714 | .await; 715 | } 716 | 717 | #[tokio::test] 718 | /* `assert_receive_json_contains` should pass when keys are omitted inside the nested "message" object. */ async fn
it_should_ping_pong_json_in_text_and_binary_with_nested_content_missing_in_contains() { 719 | let server = new_test_app(); 720 | 721 | let mut websocket = server 722 | .get_websocket(&"/ws-ping-pong") 723 | .await 724 | .into_websocket() 725 | .await; 726 | 727 | websocket 728 | .send_json(&json!({ 729 | "hello": "world", 730 | "numbers": [1, 2, 3], 731 | })) 732 | .await; 733 | 734 | // Once for text 735 | websocket 736 | .assert_receive_json_contains(&json!({ 737 | "format": "text", 738 | "message": { 739 | // "hello" is missing here 740 | "numbers": [1, 2, 3], 741 | }, 742 | })) 743 | .await; 744 | 745 | // Again for binary 746 | websocket 747 | .assert_receive_json_contains(&json!({ 748 | "format": "binary", 749 | "message": { 750 | "hello": "world", 751 | // "numbers" is missing here 752 | }, 753 | })) 754 | .await; 755 | } 756 | } 757 | 758 | #[cfg(feature = "yaml")] 759 | #[cfg(test)] 760 | mod test_assert_receive_yaml { 761 | use crate::TestServer; 762 | 763 | use axum::Router; 764 | use axum::extract::WebSocketUpgrade; 765 | use axum::extract::ws::Message; 766 | use axum::extract::ws::WebSocket; 767 | use axum::response::Response; 768 | use axum::routing::get; 769 | use serde_json::Value; 770 | use serde_json::json; 771 | 772 | /* Builds an HTTP-transport `TestServer` whose `/ws-ping-pong` route decodes each incoming message (via `into_text`) as YAML and replies with a YAML-encoded Text frame followed by a YAML-encoded Binary frame, both wrapped as {"format", "message"}. */ fn new_test_app() -> TestServer { 773 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 774 | async fn handle_ping_pong(mut socket: WebSocket) { 775 | while let Some(maybe_message) = socket.recv().await { 776 | let message_text = maybe_message.unwrap().into_text().unwrap(); 777 | /* Explicit target type: `decoded` is only consumed by `json!`, which cannot drive inference. */ let decoded: Value = serde_yaml::from_str(&message_text).unwrap(); 778 | 779 | let encoded_text = serde_yaml::to_string(&json!({ 780 | "format": "text", 781 | "message": decoded 782 | })) 783 | .unwrap() 784 | .try_into() 785 | .unwrap(); 786 | let encoded_data = serde_yaml::to_string(&json!({ 787 | "format": "binary", 788 | "message": decoded 789 | })) 790 | .unwrap() 791 | .into(); 792 | 793 |
socket.send(Message::Text(encoded_text)).await.unwrap(); 794 | socket.send(Message::Binary(encoded_data)).await.unwrap(); 795 | } 796 | } 797 | 798 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 799 | } 800 | 801 | let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)); 802 | TestServer::builder().http_transport().build(app).unwrap() 803 | } 804 | 805 | #[tokio::test] 806 | /* The client sends JSON, which the handler parses with `serde_yaml`; both the Text and the Binary YAML replies must match via `assert_receive_yaml`. */ async fn it_should_ping_pong_yaml_in_text_and_binary() { 807 | let server = new_test_app(); 808 | 809 | let mut websocket = server 810 | .get_websocket(&"/ws-ping-pong") 811 | .await 812 | .into_websocket() 813 | .await; 814 | 815 | websocket 816 | .send_json(&json!({ 817 | "hello": "world", 818 | "numbers": [1, 2, 3], 819 | })) 820 | .await; 821 | 822 | // Once for text 823 | websocket 824 | .assert_receive_yaml(&json!({ 825 | "format": "text", 826 | "message": { 827 | "hello": "world", 828 | "numbers": [1, 2, 3], 829 | }, 830 | })) 831 | .await; 832 | 833 | // Again for binary 834 | websocket 835 | .assert_receive_yaml(&json!({ 836 | "format": "binary", 837 | "message": { 838 | "hello": "world", 839 | "numbers": [1, 2, 3], 840 | }, 841 | })) 842 | .await; 843 | } 844 | } 845 | 846 | #[cfg(feature = "msgpack")] 847 | #[cfg(test)] 848 | mod test_assert_receive_msgpack { 849 | use crate::TestServer; 850 | 851 | use axum::Router; 852 | use axum::extract::WebSocketUpgrade; 853 | use axum::extract::ws::Message; 854 | use axum::extract::ws::WebSocket; 855 | use axum::response::Response; 856 | use axum::routing::get; 857 | use serde_json::Value; 858 | use serde_json::json; 859 | 860 | /* Builds an HTTP-transport `TestServer` whose `/ws-ping-pong` route decodes each incoming message's raw bytes as MessagePack and replies with a single MessagePack Binary frame wrapped as {"format", "message"}. */ fn new_test_app() -> TestServer { 861 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 862 | async fn handle_ping_pong(mut socket: WebSocket) { 863 | while let Some(maybe_message) = socket.recv().await { 864 | let message_data = maybe_message.unwrap().into_data(); 865 | /* Explicit target type: `decoded` is only consumed by `json!`, which cannot drive inference. */ let decoded: Value = rmp_serde::from_slice(&message_data).unwrap(); 866 | 867 | let
encoded_data = rmp_serde::to_vec(&json!({ 868 | "format": "binary", 869 | "message": decoded 870 | })) 871 | .unwrap() 872 | .into(); 873 | 874 | socket.send(Message::Binary(encoded_data)).await.unwrap(); 875 | } 876 | } 877 | 878 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 879 | } 880 | 881 | let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)); 882 | TestServer::builder().http_transport().build(app).unwrap() 883 | } 884 | 885 | #[tokio::test] 886 | /* The MessagePack echo only replies with a Binary frame, so a single assertion is made. */ async fn it_should_ping_pong_msgpack_in_binary() { 887 | let server = new_test_app(); 888 | 889 | let mut websocket = server 890 | .get_websocket(&"/ws-ping-pong") 891 | .await 892 | .into_websocket() 893 | .await; 894 | 895 | websocket 896 | .send_msgpack(&json!({ 897 | "hello": "world", 898 | "numbers": [1, 2, 3], 899 | })) 900 | .await; 901 | 902 | websocket 903 | .assert_receive_msgpack(&json!({ 904 | "format": "binary", 905 | "message": { 906 | "hello": "world", 907 | "numbers": [1, 2, 3], 908 | }, 909 | })) 910 | .await; 911 | } 912 | } 913 | --------------------------------------------------------------------------------