├── .github └── pull_request_template.md ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── docs.sh ├── examples ├── example-shuttle │ ├── README.md │ └── main.rs ├── example-todo │ ├── README.md │ └── main.rs ├── example-websocket-chat │ ├── README.md │ └── main.rs └── example-websocket-ping-pong │ ├── README.md │ └── main.rs ├── files ├── example.json ├── example.txt └── example.yaml ├── rust-toolchain ├── src ├── expect_json.rs ├── internals │ ├── debug_response_body.rs │ ├── expected_state.rs │ ├── format_status_code_range.rs │ ├── mod.rs │ ├── query_params_store.rs │ ├── request_path_formatter.rs │ ├── starting_tcp_setup.rs │ ├── status_code_formatter.rs │ ├── transport_layer │ │ ├── http_transport_layer.rs │ │ ├── mock_transport_layer.rs │ │ └── mod.rs │ ├── try_into_range_bounds.rs │ ├── websockets │ │ ├── mod.rs │ │ ├── test_response_websocket.rs │ │ └── ws_key_generator.rs │ └── with_this_mut.rs ├── lib.rs ├── multipart │ ├── mod.rs │ ├── multipart_form.rs │ └── part.rs ├── test_request.rs ├── test_request │ └── test_request_config.rs ├── test_response.rs ├── test_server.rs ├── test_server │ └── server_shared_state.rs ├── test_server_builder.rs ├── test_server_config.rs ├── test_web_socket.rs ├── transport.rs ├── transport_layer │ ├── into_transport_layer.rs │ ├── into_transport_layer │ │ ├── axum_service.rs │ │ ├── into_make_service.rs │ │ ├── into_make_service_with_connect_info.rs │ │ ├── router.rs │ │ ├── serve.rs │ │ ├── shuttle_axum.rs │ │ └── with_graceful_shutdown.rs │ ├── mod.rs │ ├── transport_layer.rs │ ├── transport_layer_builder.rs │ └── transport_layer_type.rs └── util │ ├── mod.rs │ ├── new_random_port.rs │ ├── new_random_socket_addr.rs │ ├── new_random_tcp_listener.rs │ ├── new_random_tokio_tcp_listener.rs │ ├── serve_handle.rs │ └── spawn_serve.rs ├── test.sh └── tests └── test-expect-json-integration.rs /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # Changes 
2 | 3 | * 4 | 5 | # Comments 6 | 7 | Any other business. 8 | 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Cargo 2 | /target/ 3 | /examples/*/target 4 | Cargo.lock 5 | 6 | # Backup files generated by rustfmt 7 | **/*.rs.bk 8 | 9 | # MSVC Windows builds of rustc generate these, which store debugging information 10 | *.pdb 11 | 12 | # Local todo file 13 | /todo.txt 14 | /todo.md 15 | 16 | # MacOS 17 | .DS_Store 18 | 19 | *.swp 20 | 21 | # IDE folders 22 | .idea 23 | .vscode 24 | 25 | # Environment files (just in case) 26 | .env 27 | .env* 28 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "axum-test" 3 | authors = ["Joseph Lenton "] 4 | version = "18.0.0-rc3" 5 | rust-version = "1.83" 6 | edition = "2021" 7 | license = "MIT" 8 | description = "For spinning up and testing Axum servers" 9 | keywords = ["testing", "test", "axum"] 10 | categories = ["web-programming::http-server", "development-tools::testing"] 11 | repository = "https://github.com/JosephLenton/axum-test" 12 | documentation = "https://docs.rs/axum-test" 13 | readme = "README.md" 14 | 15 | [package.metadata.docs.rs] 16 | all-features = true 17 | rustdoc-args = ["--cfg", "docsrs"] 18 | 19 | [[example]] 20 | name = "example-shuttle" 21 | path = "examples/example-shuttle/main.rs" 22 | required-features = ["shuttle"] 23 | 24 | [[example]] 25 | name = "example-websocket-chat" 26 | path = "examples/example-websocket-chat/main.rs" 27 | required-features = ["ws"] 28 | 29 | [[example]] 30 | name = "example-websocket-ping-pong" 31 | path = "examples/example-websocket-ping-pong/main.rs" 32 | required-features = ["ws"] 33 | 34 | [features] 35 | default = ["pretty-assertions"] 36 | 37 | all = ["pretty-assertions", "yaml", 
"msgpack", "reqwest", "shuttle", "typed-routing", "ws"] 38 | 39 | pretty-assertions = ["dep:pretty_assertions"] 40 | yaml = ["dep:serde_yaml"] 41 | msgpack = ["dep:rmp-serde"] 42 | shuttle = ["dep:shuttle-axum"] 43 | typed-routing = ["dep:axum-extra"] 44 | ws = ["axum/ws", "tokio/time", "dep:uuid", "dep:base64", "dep:tokio-tungstenite", "dep:futures-util"] 45 | reqwest = ["dep:reqwest"] 46 | 47 | # Deprecated, and will be removed in the future. 48 | old-json-diff = ["dep:assert-json-diff"] 49 | 50 | [dependencies] 51 | auto-future = "1.0" 52 | axum = { version = "0.8.4", features = [] } 53 | anyhow = "1.0" 54 | bytes = "1.10" 55 | bytesize = "2.0" 56 | cookie = "0.18" 57 | expect-json = "1.0.0" 58 | http = "1.3" 59 | http-body-util = "0.1" 60 | hyper-util = { version = "0.1", features = ["client", "http1", "client-legacy"] } 61 | hyper = { version = "1.6", features = ["http1"] } 62 | mime = "0.3" 63 | rust-multipart-rfc7578_2 = "0.8" 64 | reserve-port = "2.2" 65 | serde = "1.0" 66 | serde_json = "1.0" 67 | serde_urlencoded = "0.7" 68 | smallvec = "1.13" 69 | tokio = { version = "1.45", features = ["rt"] } 70 | tower = { version = "0.5", features = ["util", "make"] } 71 | url = "2.5" 72 | 73 | # Pretty Assertions 74 | pretty_assertions = { version = "1.4", optional = true } 75 | 76 | # Yaml 77 | serde_yaml = { version = "0.9", optional = true } 78 | 79 | # Shuttle 80 | shuttle-axum = { version = "0.54", optional = true } 81 | 82 | # MsgPack 83 | rmp-serde = { version = "1.3", optional = true } 84 | 85 | # Typed Routing 86 | axum-extra = { version = "0.10", features = ["typed-routing"], optional = true } 87 | 88 | # WebSockets 89 | uuid = { version = "1.12", optional = true, features = ["v4"]} 90 | base64 = { version = "0.22", optional = true } 91 | futures-util = { version = "0.3", optional = true } 92 | tokio-tungstenite = { version = "0.26", optional = true } 93 | 94 | # Reqwest 95 | reqwest = { version = "0.12", optional = true, features = ["cookies", "json", 
"stream", "multipart", "rustls-tls"] } 96 | 97 | # Old Json Diff 98 | assert-json-diff = { version = "2.0", optional = true } 99 | 100 | [dev-dependencies] 101 | axum = { version = "0.8", features = ["multipart", "tokio", "ws"] } 102 | axum-extra = { version = "0.10", features = ["cookie", "typed-routing", "query"] } 103 | axum-msgpack = "0.5" 104 | axum-yaml = "0.5" 105 | futures-util = "0.3" 106 | local-ip-address = "0.6" 107 | rand = { version = "0.9", features = ["small_rng"] } 108 | regex = "1.11" 109 | serde-email = { version = "3.1", features = ["serde"] } 110 | shuttle-axum = "0.54" 111 | shuttle-runtime = "0.54" 112 | tokio = { version = "1.43", features = ["rt", "rt-multi-thread", "sync", "time", "macros"] } 113 | tower-http = { version = "0.6", features = ["normalize-path"] } 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Joseph Lenton 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | Axum Test 4 |

5 | 6 |

7 | Easy E2E testing for Axum
8 | including REST, WebSockets, and more 9 |

10 | 11 | [![crate](https://img.shields.io/crates/v/axum-test.svg)](https://crates.io/crates/axum-test) 12 | [![docs](https://docs.rs/axum-test/badge.svg)](https://docs.rs/axum-test) 13 | 14 |
15 |
16 | 17 | This runs your application locally, allowing you to query against it with requests. 18 | Decode the responses, and assert what is returned. 19 | 20 | ```rust 21 | use axum::Router; 22 | use axum::routing::get; 23 | 24 | use axum_test::TestServer; 25 | 26 | #[tokio::test] 27 | async fn it_should_ping_pong() { 28 | // Build an application with a route. 29 | let app = Router::new() 30 | .route(&"/ping", get(|| async { "pong!" })); 31 | 32 | // Run the application for testing. 33 | let server = TestServer::new(app).unwrap(); 34 | 35 | // Get the request. 36 | let response = server 37 | .get("/ping") 38 | .await; 39 | 40 | // Assertions. 41 | response.assert_status_ok(); 42 | response.assert_text("pong!"); 43 | } 44 | ``` 45 | 46 | A `TestServer` enables you to run an Axum service with a mocked network, 47 | or on a random port with real network requests. 48 | In both cases, you can run multiple servers, across multiple tests, all in parallel. 49 | 50 | ## Crate Features 51 | 52 | | Feature | On by default | | 53 | |---------------------|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------| 54 | | `all` | _off_ | Turns on all features. | 55 | | `pretty-assertions` | **on** | Uses the [pretty assertions crate](https://crates.io/crates/pretty_assertions) on response `assert_*` methods. | 56 | | `yaml` | _off_ | Enables support for sending, receiving, and asserting, [yaml content](https://yaml.org/). | 57 | | `msgpack` | _off_ | Enables support for sending, receiving, and asserting, [msgpack content](https://msgpack.org/index.html). | 58 | | `shuttle` | _off_ | Enables support for building a `TestServer` from a [`shuttle_axum::AxumService`](https://docs.rs/shuttle-axum/latest/shuttle_axum/struct.AxumService.html), for use with [Shuttle.rs](https://shuttle.rs). | 59 | | `typed-routing` | _off_ | Enables support for using `TypedPath` in requests.
See [axum-extra](https://crates.io/crates/axum-extra) for details. | 60 | | `ws` | _off_ | Enables WebSocket support. See [TestWebSocket](https://docs.rs/axum-test/latest/axum_test/struct.TestWebSocket.html) for details. | 61 | | `reqwest` | _off_ | Enables the `TestServer` to create [Reqwest](https://docs.rs/reqwest) requests for querying. | 62 | | `old-json-diff` | _off_ | Switches back to the old Json diff behaviour used before Axum Test 18. If you find yourself needing this, then please raise an issue to let me know why. | 63 | 64 | ## Axum Compatibility 65 | 66 | The current version of Axum Test requires at least Axum v0.8.4. 67 | 68 | Here is the compatibility with prior versions: 69 | 70 | | Axum Version | Axum Test Version | 71 | |-----------------|-------------------| 72 | | 0.8.4+ (latest) | 18 (latest) | 73 | | 0.8.3 | 17.3 | 74 | | 0.8.0 | 17 | 75 | | 0.7.6 to 0.7.9 | 16 | 76 | | 0.7.0 to 0.7.5 | 14, 15 | 77 | | 0.6 | 13.4.1 | 78 | 79 | ## Examples 80 | 81 | You can find examples of writing tests in the [/examples folder](/examples/). 82 | These include tests for: 83 | 84 | * [a simple REST Todo application](/examples/example-todo), and [the same using Shuttle](/examples/example-shuttle) 85 | * [a WebSocket ping pong application](/examples/example-websocket-ping-pong) which sends requests up and down 86 | * [a simple WebSocket chat application](/examples/example-websocket-chat) 87 | 88 | ## Request Building Features 89 | 90 | Querying your application on the `TestServer` supports all of the common request building you would expect. 91 | 92 | - Serializing and deserializing Json, Form, Yaml, and others, using Serde 93 | - Assertions on the returned Json, text, Yaml, etc.
94 | - Cookie, query, and header setting and reading 95 | - Status code reading and assertions 96 | 97 | ### Also includes 98 | 99 | - WebSockets testing support 100 | - Saving returned cookies for use on future requests 101 | - Setting headers, query, and cookies, globally for all requests or on per request basis 102 | - Can run requests using a real web server, or with mocked HTTP 103 | - Automatic status assertions for expecting requests to succeed (to help catch bugs in tests sooner) 104 | - Prettified assertion output 105 | - Typed Routing from Axum Extra 106 | - Reqwest integration 107 | 108 | ## Contributions 109 | 110 | A big thanks to all of these who have helped! 111 | 112 | 113 | 114 | 115 | 116 | Made with [contrib.rocks](https://contrib.rocks). 117 | -------------------------------------------------------------------------------- /docs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | cargo +stable doc --features=all --open 6 | -------------------------------------------------------------------------------- /examples/example-shuttle/README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | Example REST Todo
4 |

5 | 6 |

7 | an example application with tests 8 |

9 | 10 |
11 |
12 | 13 | This is a very simple todo application, wrapped with Shuttle. It aims to show ... 14 | 15 | * How to write some basic tests against endpoints. 16 | * How to write some tests that expect success, and some that expect failure. 17 | * How to take cookies into account (like logging in). 18 | 19 | It's primarily to provide some code samples using axum-test. 20 | -------------------------------------------------------------------------------- /examples/example-shuttle/main.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! This is an example Todo Application, wrapped with Shuttle, 3 | //! to show some simple tests when using Shuttle + Axum. 4 | //! 5 | //! ```bash 6 | //! # To run its tests: 7 | //! cargo test --example=example-shuttle --features shuttle 8 | //! ``` 9 | //! 10 | //! The app includes the endpoints for ... 11 | //! 12 | //! - POST /login ... this takes an email, and returns a session cookie. 13 | //! - PUT /todo ... once logged in, you can store todos. 14 | //! - GET /todo ... once logged in, you can retrieve all todos you have stored. 15 | //! 16 | //! At the bottom of this file is a series of tests for these endpoints. 17 | //!
18 | 19 | use anyhow::anyhow; 20 | use anyhow::Result; 21 | use axum::extract::Json; 22 | use axum::extract::State; 23 | use axum::routing::get; 24 | use axum::routing::post; 25 | use axum::routing::put; 26 | use axum::Router; 27 | use axum_extra::extract::cookie::Cookie; 28 | use axum_extra::extract::cookie::CookieJar; 29 | use http::StatusCode; 30 | use serde::Deserialize; 31 | use serde::Serialize; 32 | use serde_email::Email; 33 | use std::collections::HashMap; 34 | use std::result::Result as StdResult; 35 | use std::sync::Arc; 36 | use std::sync::RwLock; 37 | 38 | #[cfg(test)] 39 | use axum_test::TestServer; 40 | 41 | /// Entry point: hands the todo application to the Shuttle runtime. 42 | #[shuttle_runtime::main] 43 | async fn main() -> ::shuttle_axum::ShuttleAxum { 44 | new_app() 45 | } 46 | 47 | /// Builds the Shuttle application: the todo routes over a shared, `RwLock`-protected in-memory store. 48 | fn new_app() -> ::shuttle_axum::ShuttleAxum { 49 | let state = AppState { 50 | user_todos: HashMap::new(), 51 | }; 52 | let shared_state = Arc::new(RwLock::new(state)); 53 | 54 | let app = Router::new() 55 | .route("/login", post(route_post_user_login)) 56 | .route("/todo", get(route_get_user_todos)) 57 | .route("/todo", put(route_put_user_todos)) 58 | .with_state(shared_state); 59 | 60 | Ok(app.into()) 61 | } 62 | 63 | /// A TestServer that runs the Shuttle application 64 | #[cfg(test)] 65 | fn new_test_app() -> TestServer { 66 | TestServer::builder() 67 | // Preserve cookies across requests 68 | // for the session cookie to work. 69 | .save_cookies() 70 | .expect_success_by_default() 71 | .mock_transport() 72 | .build(new_app()) // <- here the application is passed in 73 | .unwrap() 74 | } 75 | 76 | const USER_ID_COOKIE_NAME: &str = "example-shuttle-user-id"; 77 | 78 | type SharedAppState = Arc<RwLock<AppState>>; 79 | 80 | // This is my poor man's in-memory DB.
81 | #[derive(Debug)] 82 | pub struct AppState { 83 | user_todos: HashMap<u32, Vec<Todo>>, 84 | } 85 | 86 | #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] 87 | pub struct Todo { 88 | name: String, 89 | content: String, 90 | } 91 | 92 | #[derive(Debug, Clone, Deserialize, Serialize)] 93 | pub struct LoginRequest { 94 | user: Email, 95 | } 96 | 97 | #[derive(Debug, Clone, Deserialize, Serialize)] 98 | pub struct AllTodos { 99 | todos: Vec<Todo>, 100 | } 101 | 102 | #[derive(Debug, Clone, Deserialize, Serialize)] 103 | pub struct NumTodos { 104 | num: u32, 105 | } 106 | 107 | // Note you should never do something like this in a real application 108 | // for session cookies. It's really bad. Like _seriously_ bad. 109 | // 110 | // This is done like this here to keep the code shorter. That's all. 111 | fn get_user_id_from_cookie(cookies: &CookieJar) -> Result<u32> { 112 | cookies 113 | .get(USER_ID_COOKIE_NAME) 114 | // A missing cookie and a non-numeric value both mean "not logged in". 115 | .and_then(|c| c.value().parse::<u32>().ok()) 116 | .ok_or_else(|| anyhow!("id not found")) 117 | } 118 | 119 | pub async fn route_post_user_login( 120 | State(state): State<SharedAppState>, 121 | mut cookies: CookieJar, 122 | Json(_body): Json<LoginRequest>, 123 | ) -> CookieJar { 124 | let mut lock = state.write().unwrap(); 125 | let user_todos = &mut lock.user_todos; 126 | let user_id = user_todos.len() as u32; 127 | user_todos.insert(user_id, vec![]); 128 | 129 | let really_insecure_login_cookie = Cookie::new(USER_ID_COOKIE_NAME, user_id.to_string()); 130 | cookies = cookies.add(really_insecure_login_cookie); 131 | 132 | cookies 133 | } 134 | 135 | pub async fn route_put_user_todos( 136 | State(state): State<SharedAppState>, 137 | cookies: CookieJar, 138 | Json(todo): Json<Todo>, 139 | ) -> StdResult<Json<u32>, StatusCode> { 140 | let user_id = get_user_id_from_cookie(&cookies).map_err(|_| StatusCode::UNAUTHORIZED)?; 141 | let mut lock = state.write().unwrap(); 142 | // Reject ids with no stored todos (e.g. a stale or forged cookie) instead of panicking, which would poison the write lock for every later request. 143 | let todos = lock.user_todos.get_mut(&user_id).ok_or(StatusCode::UNAUTHORIZED)?; 144 | 145 | todos.push(todo);
146 | let num_todos = todos.len() as u32; 147 | 148 | Ok(Json(num_todos)) 149 | } 150 | 151 | pub async fn route_get_user_todos( 152 | State(state): State<SharedAppState>, 153 | cookies: CookieJar, 154 | ) -> StdResult<Json<Vec<Todo>>, StatusCode> { 155 | let user_id = get_user_id_from_cookie(&cookies).map_err(|_| StatusCode::UNAUTHORIZED)?; 156 | let lock = state.read().unwrap(); 157 | // Reject ids with no stored todos (e.g. a stale or forged cookie) instead of panicking. 158 | let todos = lock.user_todos.get(&user_id).cloned().ok_or(StatusCode::UNAUTHORIZED)?; 159 | 160 | Ok(Json(todos)) 161 | } 162 | 163 | #[cfg(test)] 164 | mod test_post_login { 165 | use super::*; 166 | 167 | use serde_json::json; 168 | 169 | #[tokio::test] 170 | async fn it_should_create_session_on_login() { 171 | let server = new_test_app(); 172 | 173 | let response = server 174 | .post(&"/login") 175 | .json(&json!({ 176 | "user": "my-login@example.com", 177 | })) 178 | .await; 179 | 180 | let session_cookie = response.cookie(&USER_ID_COOKIE_NAME); 181 | assert_ne!(session_cookie.value(), ""); 182 | } 183 | 184 | #[tokio::test] 185 | async fn it_should_not_login_using_non_email() { 186 | let server = new_test_app(); 187 | 188 | let response = server 189 | .post(&"/login") 190 | .json(&json!({ 191 | "user": "blah blah blah", 192 | })) 193 | .expect_failure() 194 | .await; 195 | 196 | // There should not be a session created.
197 | let cookie = response.maybe_cookie(&USER_ID_COOKIE_NAME); 198 | assert!(cookie.is_none()); 199 | } 200 | } 201 | 202 | #[cfg(test)] 203 | mod test_route_put_user_todos { 204 | use super::*; 205 | 206 | use serde_json::json; 207 | 208 | #[tokio::test] 209 | async fn it_should_not_store_todos_without_login() { 210 | let server = new_test_app(); 211 | 212 | let response = server 213 | .put(&"/todo") 214 | .json(&json!({ 215 | "name": "shopping", 216 | "content": "buy eggs", 217 | })) 218 | .expect_failure() 219 | .await; 220 | 221 | assert_eq!(response.status_code(), StatusCode::UNAUTHORIZED); 222 | } 223 | 224 | #[tokio::test] 225 | async fn it_should_return_number_of_todos_as_more_are_pushed() { 226 | let server = new_test_app(); 227 | 228 | server 229 | .post(&"/login") 230 | .json(&json!({ 231 | "user": "my-login@example.com", 232 | })) 233 | .await; 234 | 235 | let num_todos = server 236 | .put(&"/todo") 237 | .json(&json!({ 238 | "name": "shopping", 239 | "content": "buy eggs", 240 | })) 241 | .await 242 | .json::<u32>(); 243 | assert_eq!(num_todos, 1); 244 | 245 | let num_todos = server 246 | .put(&"/todo") 247 | .json(&json!({ 248 | "name": "afternoon", 249 | "content": "buy shoes", 250 | })) 251 | .await 252 | .json::<u32>(); 253 | assert_eq!(num_todos, 2); 254 | } 255 | } 256 | 257 | #[cfg(test)] 258 | mod test_route_get_user_todos { 259 | use super::*; 260 | 261 | use serde_json::json; 262 | 263 | #[tokio::test] 264 | async fn it_should_not_return_todos_if_logged_out() { 265 | let server = new_test_app(); 266 | 267 | // Not logged in, so no session cookie is sent: reading todos must be refused. 268 | let response = server 269 | .get(&"/todo") 270 | .expect_failure() 271 | .await; 272 | 273 | assert_eq!(response.status_code(), StatusCode::UNAUTHORIZED); 274 | } 275 | 276 | #[tokio::test] 277 | async fn it_should_return_all_todos_when_logged_in() { 278 | let server = new_test_app(); 279 | 280 | server 281 | .post(&"/login") 282 | .json(&json!({
283 | "user": "my-login@example.com", 284 | })) 285 | .await; 286 | 287 | // Push two todos. 288 | server 289 | .put(&"/todo") 290 | .json(&json!({ 291 | "name": "shopping", 292 | "content": "buy eggs", 293 | })) 294 | .await; 295 | server 296 | .put(&"/todo") 297 | .json(&json!({ 298 | "name": "afternoon", 299 | "content": "buy shoes", 300 | })) 301 | .await; 302 | 303 | // Get all todos out from the server. 304 | let todos = server.get(&"/todo").await.json::<Vec<Todo>>(); 305 | 306 | let expected_todos: Vec<Todo> = vec![ 307 | Todo { 308 | name: "shopping".to_string(), 309 | content: "buy eggs".to_string(), 310 | }, 311 | Todo { 312 | name: "afternoon".to_string(), 313 | content: "buy shoes".to_string(), 314 | }, 315 | ]; 316 | assert_eq!(todos, expected_todos) 317 | } 318 | } 319 | -------------------------------------------------------------------------------- /examples/example-todo/README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | Example REST Todo
4 |

5 | 6 |

7 | an example application with tests 8 |

9 | 10 |
11 |
12 | 13 | This is a very simple todo application. It aims to show ... 14 | 15 | * How to write some basic tests against endpoints. 16 | * How to write some tests that expect success, and some that expect failure. 17 | * How to take cookies into account (like logging in). 18 | 19 | It's primarily to provide some code samples using axum-test. 20 | -------------------------------------------------------------------------------- /examples/example-todo/main.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! This is an example Todo Application to show some simple tests. 3 | //! 4 | //! ```bash 5 | //! # To run its tests: 6 | //! cargo test --example=example-todo 7 | //! ``` 8 | //! 9 | //! The app includes the endpoints for ... 10 | //! 11 | //! - POST /login ... this takes an email, and returns a session cookie. 12 | //! - PUT /todo ... once logged in, you can store todos. 13 | //! - GET /todo ... once logged in, you can retrieve all todos you have stored. 14 | //! 15 | //! At the bottom of this file is a series of tests for these endpoints. 16 | //!
17 | 18 | use anyhow::anyhow; 19 | use anyhow::Result; 20 | use axum::extract::Json; 21 | use axum::extract::State; 22 | use axum::routing::get; 23 | use axum::routing::post; 24 | use axum::routing::put; 25 | use axum::serve::serve; 26 | use axum::Router; 27 | use axum_extra::extract::cookie::Cookie; 28 | use axum_extra::extract::cookie::CookieJar; 29 | use http::StatusCode; 30 | use serde::Deserialize; 31 | use serde::Serialize; 32 | use serde_email::Email; 33 | use std::collections::HashMap; 34 | use std::net::IpAddr; 35 | use std::net::Ipv4Addr; 36 | use std::net::SocketAddr; 37 | use std::result::Result as StdResult; 38 | use std::sync::Arc; 39 | use std::sync::RwLock; 40 | use tokio::net::TcpListener; 41 | 42 | #[cfg(test)] 43 | use axum_test::TestServer; 44 | 45 | const PORT: u16 = 8080; 46 | const USER_ID_COOKIE_NAME: &str = "example-todo-user-id"; 47 | 48 | #[tokio::main] 49 | async fn main() { 50 | let result: Result<()> = { 51 | let app = new_app(); 52 | 53 | // Start! Listen on all interfaces (0.0.0.0) at PORT. 54 | let ip_address = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); 55 | let address = SocketAddr::new(ip_address, PORT); 56 | let listener = TcpListener::bind(address).await.unwrap(); 57 | serve(listener, app.into_make_service()).await.unwrap(); 58 | 59 | Ok(()) 60 | }; 61 | 62 | match &result { 63 | Err(err) => eprintln!("{}", err), 64 | _ => {} 65 | }; 66 | } 67 | 68 | type SharedAppState = Arc<RwLock<AppState>>; 69 | 70 | // This is my poor man's in-memory DB.
71 | #[derive(Debug)] 72 | pub struct AppState { 73 | user_todos: HashMap<u32, Vec<Todo>>, 74 | } 75 | 76 | #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] 77 | pub struct Todo { 78 | name: String, 79 | content: String, 80 | } 81 | 82 | #[derive(Debug, Clone, Deserialize, Serialize)] 83 | pub struct LoginRequest { 84 | user: Email, 85 | } 86 | 87 | #[derive(Debug, Clone, Deserialize, Serialize)] 88 | pub struct AllTodos { 89 | todos: Vec<Todo>, 90 | } 91 | 92 | #[derive(Debug, Clone, Deserialize, Serialize)] 93 | pub struct NumTodos { 94 | num: u32, 95 | } 96 | 97 | // Note you should never do something like this in a real application 98 | // for session cookies. It's really bad. Like _seriously_ bad. 99 | // 100 | // This is done like this here to keep the code shorter. That's all. 101 | fn get_user_id_from_cookie(cookies: &CookieJar) -> Result<u32> { 102 | cookies 103 | .get(USER_ID_COOKIE_NAME) 104 | // A missing cookie and a non-numeric value both mean "not logged in". 105 | .and_then(|c| c.value().parse::<u32>().ok()) 106 | .ok_or_else(|| anyhow!("id not found")) 107 | } 108 | 109 | pub async fn route_post_user_login( 110 | State(state): State<SharedAppState>, 111 | mut cookies: CookieJar, 112 | Json(_body): Json<LoginRequest>, 113 | ) -> CookieJar { 114 | let mut lock = state.write().unwrap(); 115 | let user_todos = &mut lock.user_todos; 116 | let user_id = user_todos.len() as u32; 117 | user_todos.insert(user_id, vec![]); 118 | 119 | let really_insecure_login_cookie = Cookie::new(USER_ID_COOKIE_NAME, user_id.to_string()); 120 | cookies = cookies.add(really_insecure_login_cookie); 121 | 122 | cookies 123 | } 124 | 125 | pub async fn route_put_user_todos( 126 | State(state): State<SharedAppState>, 127 | cookies: CookieJar, 128 | Json(todo): Json<Todo>, 129 | ) -> StdResult<Json<u32>, StatusCode> { 130 | let user_id = get_user_id_from_cookie(&cookies).map_err(|_| StatusCode::UNAUTHORIZED)?; 131 | let mut lock = state.write().unwrap(); 132 | // Reject ids with no stored todos (e.g. a stale or forged cookie) instead of panicking, which would poison the write lock for every later request. 133 | let todos = lock.user_todos.get_mut(&user_id).ok_or(StatusCode::UNAUTHORIZED)?; 134 | 135 | todos.push(todo);
136 | let num_todos = todos.len() as u32; 137 | 138 | Ok(Json(num_todos)) 139 | } 140 | 141 | pub async fn route_get_user_todos( 142 | State(state): State<SharedAppState>, 143 | cookies: CookieJar, 144 | ) -> StdResult<Json<Vec<Todo>>, StatusCode> { 145 | let user_id = get_user_id_from_cookie(&cookies).map_err(|_| StatusCode::UNAUTHORIZED)?; 146 | let lock = state.read().unwrap(); 147 | // Reject ids with no stored todos (e.g. a stale or forged cookie) instead of panicking. 148 | let todos = lock.user_todos.get(&user_id).cloned().ok_or(StatusCode::UNAUTHORIZED)?; 149 | 150 | Ok(Json(todos)) 151 | } 152 | 153 | pub(crate) fn new_app() -> Router { 154 | let state = AppState { 155 | user_todos: HashMap::new(), 156 | }; 157 | let shared_state = Arc::new(RwLock::new(state)); 158 | 159 | Router::new() 160 | .route("/login", post(route_post_user_login)) 161 | .route("/todo", get(route_get_user_todos)) 162 | .route("/todo", put(route_put_user_todos)) 163 | .with_state(shared_state) 164 | } 165 | 166 | #[cfg(test)] 167 | fn new_test_app() -> TestServer { 168 | let app = new_app(); 169 | TestServer::builder() 170 | // Preserve cookies across requests 171 | // for the session cookie to work. 172 | .save_cookies() 173 | .expect_success_by_default() 174 | .mock_transport() 175 | .build(app) 176 | .unwrap() 177 | } 178 | 179 | #[cfg(test)] 180 | mod test_post_login { 181 | use super::*; 182 | 183 | use serde_json::json; 184 | 185 | #[tokio::test] 186 | async fn it_should_create_session_on_login() { 187 | let server = new_test_app(); 188 | 189 | let response = server 190 | .post(&"/login") 191 | .json(&json!({ 192 | "user": "my-login@example.com", 193 | })) 194 | .await; 195 | 196 | let session_cookie = response.cookie(&USER_ID_COOKIE_NAME); 197 | assert_ne!(session_cookie.value(), ""); 198 | } 199 | 200 | #[tokio::test] 201 | async fn it_should_not_login_using_non_email() { 202 | let server = new_test_app(); 203 | 204 | let response = server 205 | .post(&"/login") 206 | .json(&json!({ 207 | "user": "blah blah blah", 208 | })) 209 | .expect_failure() 210 | .await; 211 | 212 | // There should not be a session created.
213 | let cookie = response.maybe_cookie(&USER_ID_COOKIE_NAME); 214 | assert!(cookie.is_none()); 215 | } 216 | } 217 | 218 | #[cfg(test)] 219 | mod test_route_put_user_todos { 220 | use super::*; 221 | 222 | use serde_json::json; 223 | 224 | #[tokio::test] 225 | async fn it_should_not_store_todos_without_login() { 226 | let server = new_test_app(); 227 | 228 | let response = server 229 | .put(&"/todo") 230 | .json(&json!({ 231 | "name": "shopping", 232 | "content": "buy eggs", 233 | })) 234 | .expect_failure() 235 | .await; 236 | 237 | assert_eq!(response.status_code(), StatusCode::UNAUTHORIZED); 238 | } 239 | 240 | #[tokio::test] 241 | async fn it_should_return_number_of_todos_as_more_are_pushed() { 242 | let server = new_test_app(); 243 | 244 | server 245 | .post(&"/login") 246 | .json(&json!({ 247 | "user": "my-login@example.com", 248 | })) 249 | .await; 250 | 251 | let num_todos = server 252 | .put(&"/todo") 253 | .json(&json!({ 254 | "name": "shopping", 255 | "content": "buy eggs", 256 | })) 257 | .await 258 | .json::<u32>(); 259 | assert_eq!(num_todos, 1); 260 | 261 | let num_todos = server 262 | .put(&"/todo") 263 | .json(&json!({ 264 | "name": "afternoon", 265 | "content": "buy shoes", 266 | })) 267 | .await 268 | .json::<u32>(); 269 | assert_eq!(num_todos, 2); 270 | } 271 | } 272 | 273 | #[cfg(test)] 274 | mod test_route_get_user_todos { 275 | use super::*; 276 | 277 | use serde_json::json; 278 | 279 | #[tokio::test] 280 | async fn it_should_not_return_todos_if_logged_out() { 281 | let server = new_test_app(); 282 | 283 | // Not logged in, so no session cookie is sent: reading todos must be refused. 284 | let response = server 285 | .get(&"/todo") 286 | .expect_failure() 287 | .await; 288 | 289 | assert_eq!(response.status_code(), StatusCode::UNAUTHORIZED); 290 | } 291 | 292 | #[tokio::test] 293 | async fn it_should_return_all_todos_when_logged_in() { 294 | let server = new_test_app(); 295 | 296 | server 297 | .post(&"/login") 298 | .json(&json!({
299 | "user": "my-login@example.com", 300 | })) 301 | .await; 302 | 303 | // Push two todos. 304 | server 305 | .put(&"/todo") 306 | .json(&json!({ 307 | "name": "shopping", 308 | "content": "buy eggs", 309 | })) 310 | .await; 311 | server 312 | .put(&"/todo") 313 | .json(&json!({ 314 | "name": "afternoon", 315 | "content": "buy shoes", 316 | })) 317 | .await; 318 | 319 | // Get all todos out from the server. 320 | let todos = server.get(&"/todo").await.json::<Vec<Todo>>(); 321 | 322 | let expected_todos: Vec<Todo> = vec![ 323 | Todo { 324 | name: "shopping".to_string(), 325 | content: "buy eggs".to_string(), 326 | }, 327 | Todo { 328 | name: "afternoon".to_string(), 329 | content: "buy shoes".to_string(), 330 | }, 331 | ]; 332 | assert_eq!(todos, expected_todos) 333 | } 334 | } 335 | -------------------------------------------------------------------------------- /examples/example-websocket-chat/README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | Example WebSockets Chat
4 |

5 | 6 |

7 | a simple chat application with tests 8 |

9 | 10 |
11 |
12 | 13 | This is a very simple application using WebSockets. It aims to show ... 14 | 15 | * How to write a very basic chat application, 16 | * and include tests which send and receive data. 17 | 18 | It's primarily to provide some code samples using axum-test. 19 | -------------------------------------------------------------------------------- /examples/example-websocket-chat/main.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! This is an example Todo Application using Web Sockets for communication. 3 | //! 4 | //! At the bottom of this file are a series of tests for using websockets. 5 | //! 6 | //! ```bash 7 | //! # To run it's tests: 8 | //! cargo test --example=example-websocket-chat --features ws 9 | //! ``` 10 | //! 11 | 12 | use anyhow::Result; 13 | use axum::extract::ws::Message; 14 | use axum::extract::ws::WebSocket; 15 | use axum::extract::Path; 16 | use axum::extract::State; 17 | use axum::extract::WebSocketUpgrade; 18 | use axum::response::Response; 19 | use axum::routing::get; 20 | use axum::serve::serve; 21 | use axum::Router; 22 | use futures_util::SinkExt; 23 | use futures_util::StreamExt; 24 | use serde::Deserialize; 25 | use serde::Serialize; 26 | use std::collections::HashMap; 27 | use std::net::IpAddr; 28 | use std::net::Ipv4Addr; 29 | use std::net::SocketAddr; 30 | use std::sync::Arc; 31 | use std::time::Duration; 32 | use tokio::net::TcpListener; 33 | use tokio::sync::RwLock; 34 | 35 | #[cfg(test)] 36 | use axum_test::TestServer; 37 | 38 | const PORT: u16 = 8080; 39 | 40 | #[tokio::main] 41 | async fn main() { 42 | let result: Result<()> = { 43 | let app = new_app(); 44 | 45 | // Start! 
46 | let ip_address = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); 47 | let address = SocketAddr::new(ip_address, PORT); 48 | let listener = TcpListener::bind(address).await.unwrap(); 49 | serve(listener, app.into_make_service()).await.unwrap(); 50 | 51 | Ok(()) 52 | }; 53 | 54 | match &result { 55 | Err(err) => eprintln!("{}", err), 56 | _ => {} 57 | }; 58 | } 59 | 60 | type SharedAppState = Arc>; 61 | 62 | /// This my poor mans chat system. 63 | /// 64 | /// It holds a map of User ID to Messages. 65 | #[derive(Debug)] 66 | pub struct AppState { 67 | user_messages: HashMap>, 68 | } 69 | 70 | #[derive(Deserialize, Serialize, Debug, PartialEq)] 71 | pub struct ChatSendMessage { 72 | pub to: String, 73 | pub message: String, 74 | } 75 | 76 | #[derive(Deserialize, Serialize, Debug, PartialEq)] 77 | pub struct ChatReceivedMessage { 78 | pub from: String, 79 | pub message: String, 80 | } 81 | 82 | pub async fn route_get_websocket_chat( 83 | State(state): State, 84 | Path(username): Path, 85 | ws: WebSocketUpgrade, 86 | ) -> Response { 87 | ws.on_upgrade(move |socket| handle_chat(socket, username, state.clone())) 88 | } 89 | 90 | async fn handle_chat(socket: WebSocket, username: String, state: SharedAppState) { 91 | let (mut sender, mut receiver) = socket.split(); 92 | 93 | // Spawn a task that will push several messages to the client (does not matter what client does) 94 | let send_state = state.clone(); 95 | let send_username = username.clone(); 96 | let mut send_task = tokio::spawn(async move { 97 | loop { 98 | let mut state_locked = send_state.write().await; 99 | let maybe_messages = state_locked.user_messages.get_mut(&send_username); 100 | 101 | if let Some(messages) = maybe_messages { 102 | while let Some(message) = messages.pop() { 103 | let json_text = serde_json::to_string(&message) 104 | .expect("Failed to build JSON message for sending"); 105 | 106 | sender 107 | .send(Message::Text(json_text.into())) 108 | .await 109 | .expect("Failed to send message to socket"); 
110 | } 111 | } 112 | 113 | ::tokio::time::sleep(Duration::from_millis(10)).await; 114 | } 115 | }); 116 | 117 | // This second task will receive messages from client and print them on server console 118 | let mut recv_task = tokio::spawn(async move { 119 | while let Some(Ok(message)) = receiver.next().await { 120 | let raw_text = message 121 | .into_text() 122 | .expect("Failed to read text from incoming message"); 123 | let decoded = serde_json::from_str::(&raw_text) 124 | .expect("Failed to decode incoming JSON message"); 125 | 126 | let mut state_locked = state.write().await; 127 | let maybe_messages = state_locked.user_messages.entry(decoded.to); 128 | maybe_messages.or_default().push(ChatReceivedMessage { 129 | from: username.clone(), 130 | message: decoded.message, 131 | }); 132 | } 133 | }); 134 | 135 | // If any one of the tasks exit, abort the other. 136 | tokio::select! { 137 | rv_a = (&mut send_task) => { 138 | match rv_a { 139 | Ok(_) => println!("Messages sent"), 140 | Err(a) => println!("Error sending messages {a:?}") 141 | } 142 | recv_task.abort(); 143 | }, 144 | rv_b = (&mut recv_task) => { 145 | match rv_b { 146 | Ok(_) => println!("Received messages"), 147 | Err(b) => println!("Error receiving messages {b:?}") 148 | } 149 | send_task.abort(); 150 | } 151 | } 152 | } 153 | 154 | pub(crate) fn new_app() -> Router { 155 | let state = AppState { 156 | user_messages: HashMap::new(), 157 | }; 158 | let shared_state = Arc::new(RwLock::new(state)); 159 | 160 | Router::new() 161 | .route(&"/ws-chat/{name}", get(route_get_websocket_chat)) 162 | .with_state(shared_state) 163 | } 164 | 165 | #[cfg(test)] 166 | fn new_test_app() -> TestServer { 167 | let app = new_app(); 168 | TestServer::builder() 169 | .http_transport() // Important! It must be a HTTP Transport here. 
170 | .build(app) 171 | .unwrap() 172 | } 173 | 174 | #[cfg(test)] 175 | mod test_websockets_chat { 176 | use super::*; 177 | 178 | #[tokio::test] 179 | async fn it_should_start_a_websocket_connection() { 180 | let server = new_test_app(); 181 | 182 | let response = server.get_websocket(&"/ws-chat/john").await; 183 | 184 | response.assert_status_switching_protocols(); 185 | } 186 | 187 | #[tokio::test] 188 | async fn it_should_send_messages_back_and_forth() { 189 | let server = new_test_app(); 190 | 191 | let mut alice_chat = server 192 | .get_websocket(&"/ws-chat/alice") 193 | .await 194 | .into_websocket() 195 | .await; 196 | 197 | let mut bob_chat = server 198 | .get_websocket(&"/ws-chat/bob") 199 | .await 200 | .into_websocket() 201 | .await; 202 | 203 | bob_chat 204 | .send_json(&ChatSendMessage { 205 | to: "alice".to_string(), 206 | message: "How are you Alice?".to_string(), 207 | }) 208 | .await; 209 | 210 | alice_chat 211 | .assert_receive_json(&ChatReceivedMessage { 212 | from: "bob".to_string(), 213 | message: "How are you Alice?".to_string(), 214 | }) 215 | .await; 216 | alice_chat 217 | .send_json(&ChatSendMessage { 218 | to: "bob".to_string(), 219 | message: "I am good".to_string(), 220 | }) 221 | .await; 222 | alice_chat 223 | .send_json(&ChatSendMessage { 224 | to: "bob".to_string(), 225 | message: "How are you?".to_string(), 226 | }) 227 | .await; 228 | 229 | bob_chat 230 | .assert_receive_json(&ChatReceivedMessage { 231 | from: "alice".to_string(), 232 | message: "I am good".to_string(), 233 | }) 234 | .await; 235 | bob_chat 236 | .assert_receive_json(&ChatReceivedMessage { 237 | from: "alice".to_string(), 238 | message: "How are you?".to_string(), 239 | }) 240 | .await; 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /examples/example-websocket-ping-pong/README.md: -------------------------------------------------------------------------------- 1 |
2 |

3 | Example WebSockets Ping Pong
4 |

5 | 6 |

7 | an example WebSocket application with tests 8 |

9 | 10 |
11 |
12 | 13 | This is a very simple application using WebSockets. It aims to show ... 14 | 15 | * How to write a basic test that starts a WebSocket connection. 16 | * A basic ping pong test, where data is pushed up and down. 17 | 18 | It's primarily to provide some code samples using axum-test. 19 | -------------------------------------------------------------------------------- /examples/example-websocket-ping-pong/main.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! This is a simple WebSocket example Application. 3 | //! You send it data, and it will send it back. 4 | //! 5 | //! At the bottom of this file are a series of tests for using websockets. 6 | //! 7 | //! ```bash 8 | //! # To run it's tests: 9 | //! cargo test --example=example-websocket-ping-pong --features ws 10 | //! ``` 11 | //! 12 | 13 | use anyhow::Result; 14 | use axum::extract::ws::WebSocket; 15 | use axum::extract::WebSocketUpgrade; 16 | use axum::response::Response; 17 | use axum::routing::get; 18 | use axum::serve::serve; 19 | use axum::Router; 20 | use std::net::IpAddr; 21 | use std::net::Ipv4Addr; 22 | use std::net::SocketAddr; 23 | use tokio::net::TcpListener; 24 | 25 | #[cfg(test)] 26 | use axum_test::TestServer; 27 | 28 | const PORT: u16 = 8080; 29 | 30 | #[tokio::main] 31 | async fn main() { 32 | let result: Result<()> = { 33 | let app = new_app(); 34 | 35 | // Start! 
36 | let ip_address = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); 37 | let address = SocketAddr::new(ip_address, PORT); 38 | let listener = TcpListener::bind(address).await.unwrap(); 39 | serve(listener, app.into_make_service()).await.unwrap(); 40 | 41 | Ok(()) 42 | }; 43 | 44 | match &result { 45 | Err(err) => eprintln!("{}", err), 46 | _ => {} 47 | }; 48 | } 49 | 50 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 51 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 52 | } 53 | 54 | async fn handle_ping_pong(mut socket: WebSocket) { 55 | while let Some(msg) = socket.recv().await { 56 | let msg = if let Ok(msg) = msg { 57 | msg 58 | } else { 59 | // client disconnected 60 | return; 61 | }; 62 | 63 | if socket.send(msg).await.is_err() { 64 | // client disconnected 65 | return; 66 | } 67 | } 68 | } 69 | 70 | pub(crate) fn new_app() -> Router { 71 | Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)) 72 | } 73 | 74 | #[cfg(test)] 75 | fn new_test_app() -> TestServer { 76 | let app = new_app(); 77 | TestServer::builder() 78 | .http_transport() // Important! It must be a HTTP Transport here. 
79 | .build(app) 80 | .unwrap() 81 | } 82 | 83 | #[cfg(test)] 84 | mod test_websockets_ping_pong { 85 | use super::*; 86 | 87 | use serde_json::json; 88 | 89 | #[tokio::test] 90 | async fn it_should_start_a_websocket_connection() { 91 | let server = new_test_app(); 92 | 93 | let response = server.get_websocket(&"/ws-ping-pong").await; 94 | 95 | response.assert_status_switching_protocols(); 96 | } 97 | 98 | #[tokio::test] 99 | async fn it_should_ping_pong_text() { 100 | let server = new_test_app(); 101 | 102 | let mut websocket = server 103 | .get_websocket(&"/ws-ping-pong") 104 | .await 105 | .into_websocket() 106 | .await; 107 | 108 | websocket.send_text("Hello!").await; 109 | websocket.assert_receive_text("Hello!").await; 110 | } 111 | 112 | #[tokio::test] 113 | async fn it_should_ping_pong_json() { 114 | let server = new_test_app(); 115 | 116 | let mut websocket = server 117 | .get_websocket(&"/ws-ping-pong") 118 | .await 119 | .into_websocket() 120 | .await; 121 | 122 | websocket 123 | .send_json(&json!({ 124 | "hello": "world", 125 | "numbers": [1, 2, 3], 126 | })) 127 | .await; 128 | websocket 129 | .assert_receive_json(&json!({ 130 | "hello": "world", 131 | "numbers": [1, 2, 3], 132 | })) 133 | .await; 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /files/example.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Joe", 3 | "age": 20 4 | } 5 | -------------------------------------------------------------------------------- /files/example.txt: -------------------------------------------------------------------------------- 1 | hello! 
-------------------------------------------------------------------------------- /files/example.yaml: -------------------------------------------------------------------------------- 1 | name: Joe 2 | age: 20 -------------------------------------------------------------------------------- /rust-toolchain: -------------------------------------------------------------------------------- 1 | stable -------------------------------------------------------------------------------- /src/expect_json.rs: -------------------------------------------------------------------------------- 1 | pub use ::expect_json::expect; 2 | 3 | /// This macro is for defining your own custom [`ExpectOp`] checks. 4 | #[doc(inline)] 5 | pub use ::expect_json::expect_op_for_axum_test as expect_op; 6 | 7 | pub use ::expect_json::ops; 8 | pub use ::expect_json::Context; 9 | pub use ::expect_json::ExpectJsonError; 10 | pub use ::expect_json::ExpectJsonResult; 11 | pub use ::expect_json::ExpectOp; 12 | pub use ::expect_json::ExpectOpError; 13 | pub use ::expect_json::ExpectOpExt; 14 | pub use ::expect_json::ExpectOpResult; 15 | pub use ::expect_json::JsonType; 16 | 17 | #[doc(hidden)] 18 | pub use ::expect_json::ExpectOpSerialize; 19 | #[doc(hidden)] 20 | pub use ::expect_json::SerializeExpectOp; 21 | #[doc(hidden)] 22 | pub use ::expect_json::__private; 23 | -------------------------------------------------------------------------------- /src/internals/debug_response_body.rs: -------------------------------------------------------------------------------- 1 | use crate::TestResponse; 2 | use bytesize::ByteSize; 3 | use std::fmt::Display; 4 | use std::fmt::Formatter; 5 | use std::fmt::Result as FmtResult; 6 | 7 | /// An arbituary limit to avoid printing gigabytes to the terminal. 
8 | const MAX_TEXT_PRINT_LEN: usize = 10_000; 9 | 10 | #[derive(Debug)] 11 | pub struct DebugResponseBody<'a>(pub &'a TestResponse); 12 | 13 | impl Display for DebugResponseBody<'_> { 14 | fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { 15 | match self.0.maybe_content_type() { 16 | Some(content_type) => { 17 | match content_type.as_str() { 18 | // Json 19 | "application/json" | "text/json" => write_json(f, self.0), 20 | 21 | // Msgpack 22 | "application/msgpack" => write!(f, ""), 23 | 24 | // Yaml 25 | #[cfg(feature = "yaml")] 26 | "application/yaml" | "application/x-yaml" | "text/yaml" => { 27 | write_yaml(f, self.0) 28 | } 29 | 30 | #[cfg(not(feature = "yaml"))] 31 | "application/yaml" | "application/x-yaml" | "text/yaml" => { 32 | write_text(f, &self.0.text()) 33 | } 34 | 35 | // Text Content 36 | s if s.starts_with("text/") => write_text(f, &self.0.text()), 37 | 38 | // Byte Streams 39 | "application/octet-stream" => { 40 | let len = self.0.as_bytes().len(); 41 | write!(f, "", ByteSize(len as u64)) 42 | } 43 | 44 | // Unknown content type 45 | _ => { 46 | let len = self.0.as_bytes().len(); 47 | write!( 48 | f, 49 | "", 50 | ByteSize(len as u64) 51 | ) 52 | } 53 | } 54 | } 55 | 56 | // We just default to text 57 | _ => write_text(f, &self.0.text()), 58 | } 59 | } 60 | } 61 | 62 | fn write_text(f: &mut Formatter<'_>, text: &str) -> FmtResult { 63 | let len = text.len(); 64 | 65 | if len < MAX_TEXT_PRINT_LEN { 66 | write!(f, "'{}'", text) 67 | } else { 68 | let text_start = text.chars().take(MAX_TEXT_PRINT_LEN); 69 | write!(f, "'")?; 70 | for c in text_start { 71 | write!(f, "{c}")?; 72 | } 73 | write!(f, "...'")?; 74 | 75 | Ok(()) 76 | } 77 | } 78 | 79 | fn write_json(f: &mut Formatter<'_>, response: &TestResponse) -> FmtResult { 80 | let bytes = response.as_bytes(); 81 | let result = serde_json::from_slice::(bytes); 82 | 83 | match result { 84 | Err(_) => { 85 | write!( 86 | f, 87 | "!!! 
YOUR JSON IS MALFORMED !!!\nBody: '{}'", 88 | response.text() 89 | ) 90 | } 91 | Ok(body) => { 92 | let pretty_raw = serde_json::to_string_pretty(&body) 93 | .expect("Failed to reserialise serde_json::Value of request body"); 94 | write!(f, "{pretty_raw}") 95 | } 96 | } 97 | } 98 | 99 | #[cfg(feature = "yaml")] 100 | fn write_yaml(f: &mut Formatter<'_>, response: &TestResponse) -> FmtResult { 101 | let response_bytes = response.as_bytes(); 102 | let result = serde_yaml::from_slice::(response_bytes); 103 | 104 | match result { 105 | Err(_) => { 106 | write!( 107 | f, 108 | "!!! YOUR YAML IS MALFORMED !!!\nBody: '{}'", 109 | response.text() 110 | ) 111 | } 112 | Ok(body) => { 113 | let pretty_raw = serde_yaml::to_string(&body) 114 | .expect("Failed to reserialise serde_yaml::Value of request body"); 115 | write!(f, "{pretty_raw}") 116 | } 117 | } 118 | } 119 | 120 | #[cfg(test)] 121 | mod test_fmt { 122 | use super::*; 123 | use crate::TestServer; 124 | use axum::body::Body; 125 | use axum::response::IntoResponse; 126 | use axum::response::Response; 127 | use axum::routing::get; 128 | use axum::Json; 129 | use axum::Router; 130 | use http::header; 131 | use http::HeaderValue; 132 | use pretty_assertions::assert_eq; 133 | use serde::Deserialize; 134 | use serde::Serialize; 135 | 136 | #[derive(Serialize, Deserialize, PartialEq, Debug)] 137 | struct ExampleResponse { 138 | name: String, 139 | age: u32, 140 | } 141 | 142 | #[tokio::test] 143 | async fn it_should_display_text_response_as_text() { 144 | let router = Router::new().route("/text", get(|| async { "Blah blah" })); 145 | let response = TestServer::new(router).unwrap().get("/text").await; 146 | 147 | let debug_body = DebugResponseBody(&response); 148 | let output = format!("{debug_body}"); 149 | 150 | assert_eq!(output, "'Blah blah'"); 151 | } 152 | 153 | #[tokio::test] 154 | async fn it_should_cutoff_very_long_text() { 155 | let router = Router::new().route( 156 | "/text", 157 | get(|| async { 158 | let max_len 
= MAX_TEXT_PRINT_LEN + 100; 159 | (0..max_len).map(|_| "🦊").collect::() 160 | }), 161 | ); 162 | let response = TestServer::new(router).unwrap().get("/text").await; 163 | 164 | let debug_body = DebugResponseBody(&response); 165 | let output = format!("{debug_body}"); 166 | 167 | let expected_content = (0..MAX_TEXT_PRINT_LEN).map(|_| "🦊").collect::(); 168 | let expected = format!("'{expected_content}...'"); 169 | 170 | assert_eq!(output, expected); 171 | } 172 | 173 | #[tokio::test] 174 | async fn it_should_pretty_print_json() { 175 | let router = Router::new().route( 176 | "/json", 177 | get(|| async { 178 | Json(ExampleResponse { 179 | name: "Joe".to_string(), 180 | age: 20, 181 | }) 182 | }), 183 | ); 184 | let response = TestServer::new(router).unwrap().get("/json").await; 185 | 186 | let debug_body = DebugResponseBody(&response); 187 | let output = format!("{debug_body}"); 188 | let expected = r###"{ 189 | "age": 20, 190 | "name": "Joe" 191 | }"###; 192 | 193 | assert_eq!(output, expected); 194 | } 195 | 196 | #[tokio::test] 197 | async fn it_should_warn_malformed_json() { 198 | let router = Router::new().route( 199 | "/json", 200 | get(|| async { 201 | let body = Body::new(r###"{ "name": "Joe" "###.to_string()); 202 | 203 | Response::builder() 204 | .header( 205 | header::CONTENT_TYPE, 206 | HeaderValue::from_static("application/json"), 207 | ) 208 | .body(body) 209 | .unwrap() 210 | .into_response() 211 | }), 212 | ); 213 | let response = TestServer::new(router).unwrap().get("/json").await; 214 | 215 | let debug_body = DebugResponseBody(&response); 216 | let output = format!("{debug_body}"); 217 | let expected = r###"!!! YOUR JSON IS MALFORMED !!! 
218 | Body: '{ "name": "Joe" '"###; 219 | 220 | assert_eq!(output, expected); 221 | } 222 | 223 | #[cfg(feature = "yaml")] 224 | #[tokio::test] 225 | async fn it_should_pretty_print_yaml() { 226 | use axum_yaml::Yaml; 227 | 228 | let router = Router::new().route( 229 | "/yaml", 230 | get(|| async { 231 | Yaml(ExampleResponse { 232 | name: "Joe".to_string(), 233 | age: 20, 234 | }) 235 | }), 236 | ); 237 | let response = TestServer::new(router).unwrap().get("/yaml").await; 238 | 239 | let debug_body = DebugResponseBody(&response); 240 | let output = format!("{debug_body}"); 241 | let expected = r###"name: Joe 242 | age: 20 243 | "###; 244 | 245 | assert_eq!(output, expected); 246 | } 247 | 248 | #[cfg(feature = "yaml")] 249 | #[tokio::test] 250 | async fn it_should_warn_on_malformed_yaml() { 251 | let router = Router::new().route( 252 | "/yaml", 253 | get(|| async { 254 | let body = Body::new("🦊 🦊 🦊: : :🦊 🦊 🦊".to_string()); 255 | 256 | Response::builder() 257 | .header( 258 | header::CONTENT_TYPE, 259 | HeaderValue::from_static("application/yaml"), 260 | ) 261 | .body(body) 262 | .unwrap() 263 | .into_response() 264 | }), 265 | ); 266 | let response = TestServer::new(router).unwrap().get("/yaml").await; 267 | 268 | let debug_body = DebugResponseBody(&response); 269 | let output = format!("{debug_body}"); 270 | let expected = r###"!!! YOUR YAML IS MALFORMED !!! 
271 | Body: '🦊 🦊 🦊: : :🦊 🦊 🦊'"###; 272 | 273 | assert_eq!(output, expected); 274 | } 275 | } 276 | -------------------------------------------------------------------------------- /src/internals/expected_state.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, PartialEq, Clone, Copy, Eq, Hash)] 2 | pub enum ExpectedState { 3 | Success, 4 | Failure, 5 | None, 6 | } 7 | 8 | impl From> for ExpectedState { 9 | fn from(maybe_success: Option) -> Self { 10 | match maybe_success { 11 | None => Self::None, 12 | Some(true) => Self::Success, 13 | Some(false) => Self::Failure, 14 | } 15 | } 16 | } 17 | 18 | #[cfg(test)] 19 | mod test_from { 20 | use super::*; 21 | 22 | #[test] 23 | fn it_should_turn_none_to_none() { 24 | let output = ExpectedState::from(None); 25 | assert_eq!(output, ExpectedState::None); 26 | } 27 | 28 | #[test] 29 | fn it_should_turn_true_to_success() { 30 | let output = ExpectedState::from(Some(true)); 31 | assert_eq!(output, ExpectedState::Success); 32 | } 33 | 34 | #[test] 35 | fn it_should_turn_false_to_failure() { 36 | let output = ExpectedState::from(Some(false)); 37 | assert_eq!(output, ExpectedState::Failure); 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/internals/format_status_code_range.rs: -------------------------------------------------------------------------------- 1 | use http::StatusCode; 2 | use std::fmt::Write; 3 | use std::ops::Bound; 4 | use std::ops::RangeBounds; 5 | 6 | pub fn format_status_code_range(range: R) -> String 7 | where 8 | R: RangeBounds, 9 | { 10 | let mut output = String::new(); 11 | 12 | let start = range.start_bound(); 13 | let end = range.end_bound(); 14 | 15 | match start { 16 | Bound::Included(code) | Bound::Excluded(code) => { 17 | write!(output, "{}", code.as_u16()).expect("Failed to build debug string"); 18 | } 19 | Bound::Unbounded => {} 20 | }; 21 | 22 | write!(output, "..").expect("Failed to build 
debug string"); 23 | 24 | match end { 25 | Bound::Included(code) => { 26 | write!(output, "={}", code.as_u16()).expect("Failed to build debug string"); 27 | } 28 | Bound::Excluded(code) => { 29 | write!(output, "{}", code.as_u16()).expect("Failed to build debug string"); 30 | } 31 | Bound::Unbounded => {} 32 | }; 33 | 34 | output 35 | } 36 | 37 | #[cfg(test)] 38 | mod test_format_status_code_range { 39 | use super::*; 40 | 41 | #[test] 42 | fn it_should_format_range() { 43 | let output = format_status_code_range(StatusCode::OK..StatusCode::IM_USED); 44 | assert_eq!(output, "200..226"); 45 | } 46 | 47 | #[test] 48 | fn it_should_format_range_inclusive() { 49 | let output = format_status_code_range(StatusCode::OK..=StatusCode::IM_USED); 50 | assert_eq!(output, "200..=226"); 51 | } 52 | 53 | #[test] 54 | fn it_should_format_range_from() { 55 | let output = format_status_code_range(StatusCode::OK..); 56 | assert_eq!(output, "200.."); 57 | } 58 | 59 | #[test] 60 | fn it_should_format_range_to() { 61 | let output = format_status_code_range(..StatusCode::IM_USED); 62 | assert_eq!(output, "..226"); 63 | } 64 | 65 | #[test] 66 | fn it_should_format_range_to_inclusive() { 67 | let output = format_status_code_range(..=StatusCode::IM_USED); 68 | assert_eq!(output, "..=226"); 69 | } 70 | 71 | #[test] 72 | fn it_should_format_range_full() { 73 | let output = format_status_code_range(..); 74 | assert_eq!(output, ".."); 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/internals/mod.rs: -------------------------------------------------------------------------------- 1 | mod transport_layer; 2 | pub use self::transport_layer::*; 3 | 4 | #[cfg(feature = "ws")] 5 | mod websockets; 6 | #[cfg(feature = "ws")] 7 | pub use self::websockets::*; 8 | 9 | mod debug_response_body; 10 | pub use self::debug_response_body::*; 11 | 12 | mod expected_state; 13 | pub use self::expected_state::*; 14 | 15 | mod format_status_code_range; 16 | pub 
use self::format_status_code_range::*; 17 | 18 | mod status_code_formatter; 19 | pub use self::status_code_formatter::*; 20 | 21 | mod request_path_formatter; 22 | pub use self::request_path_formatter::*; 23 | 24 | mod query_params_store; 25 | pub use self::query_params_store::*; 26 | 27 | mod try_into_range_bounds; 28 | pub use self::try_into_range_bounds::*; 29 | 30 | mod starting_tcp_setup; 31 | pub use self::starting_tcp_setup::*; 32 | 33 | mod with_this_mut; 34 | pub use self::with_this_mut::*; 35 | -------------------------------------------------------------------------------- /src/internals/query_params_store.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use serde::Serialize; 3 | use smallvec::SmallVec; 4 | use std::fmt::Display; 5 | use std::fmt::Formatter; 6 | use std::fmt::Result as FmtResult; 7 | 8 | #[derive(Debug, Clone, PartialEq)] 9 | pub struct QueryParamsStore { 10 | query_params: SmallVec<[String; 0]>, 11 | } 12 | 13 | impl QueryParamsStore { 14 | pub fn new() -> Self { 15 | Self { 16 | query_params: SmallVec::new(), 17 | } 18 | } 19 | 20 | pub fn add(&mut self, query_params: V) -> Result<()> 21 | where 22 | V: Serialize, 23 | { 24 | let value_raw = ::serde_urlencoded::to_string(query_params)?; 25 | self.add_raw(value_raw); 26 | 27 | Ok(()) 28 | } 29 | 30 | pub fn add_raw(&mut self, value_raw: String) { 31 | self.query_params.push(value_raw); 32 | } 33 | 34 | pub fn clear(&mut self) { 35 | self.query_params.clear(); 36 | } 37 | 38 | pub fn is_empty(&self) -> bool { 39 | self.query_params.is_empty() 40 | } 41 | 42 | pub fn has_content(&self) -> bool { 43 | !self.is_empty() 44 | } 45 | } 46 | 47 | impl Display for QueryParamsStore { 48 | fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { 49 | let mut is_joining = false; 50 | for query in &self.query_params { 51 | if is_joining { 52 | write!(f, "&")?; 53 | } 54 | 55 | write!(f, "{}", query)?; 56 | is_joining = true; 57 | } 58 | 59 | 
Ok(()) 60 | } 61 | } 62 | 63 | #[cfg(test)] 64 | mod test_add { 65 | use super::*; 66 | 67 | #[test] 68 | fn it_should_add_multiple_key_values() { 69 | let mut params = QueryParamsStore::new(); 70 | 71 | params 72 | .add(&[("key", "value"), ("another", "value")]) 73 | .unwrap(); 74 | 75 | assert_eq!("key=value&another=value", params.to_string()); 76 | } 77 | 78 | #[test] 79 | fn it_should_add_multiple_calls() { 80 | let mut params = QueryParamsStore::new(); 81 | 82 | params.add(&[("key", "value")]).unwrap(); 83 | params.add(&[("another", "value")]).unwrap(); 84 | 85 | assert_eq!("key=value&another=value", params.to_string()); 86 | } 87 | 88 | #[test] 89 | fn it_should_reject_raw_string() { 90 | let mut params = QueryParamsStore::new(); 91 | 92 | let result = params.add("key=value"); 93 | 94 | assert!(result.is_err()); 95 | } 96 | 97 | #[test] 98 | fn it_should_add_query_param_strings_deserialized() { 99 | let mut params = QueryParamsStore::new(); 100 | 101 | params.add(&[("key", "value&another=value")]).unwrap(); 102 | 103 | assert_eq!("key=value%26another%3Dvalue", params.to_string()); 104 | } 105 | } 106 | 107 | #[cfg(test)] 108 | mod test_add_raw { 109 | use crate::internals::QueryParamsStore; 110 | 111 | #[test] 112 | fn it_should_add_key_value_pairs_correctly() { 113 | let mut params = QueryParamsStore::new(); 114 | 115 | params.add_raw("key=value".to_string()); 116 | params.add_raw("another=value".to_string()); 117 | 118 | assert_eq!("key=value&another=value", params.to_string()); 119 | } 120 | 121 | #[test] 122 | fn it_should_add_single_keys_correctly() { 123 | let mut params = QueryParamsStore::new(); 124 | 125 | params.add_raw("key".to_string()); 126 | params.add_raw("another".to_string()); 127 | 128 | assert_eq!("key&another", params.to_string()); 129 | } 130 | 131 | #[test] 132 | fn it_should_add_query_param_strings_correctly() { 133 | let mut params = QueryParamsStore::new(); 134 | 135 | params.add_raw("key=value&another=value".to_string()); 136 | 
params.add_raw("more=value".to_string()); 137 | 138 | assert_eq!("key=value&another=value&more=value", params.to_string()); 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /src/internals/request_path_formatter.rs: -------------------------------------------------------------------------------- 1 | use http::Method; 2 | use std::fmt; 3 | 4 | use crate::internals::QueryParamsStore; 5 | 6 | #[derive(Debug, Clone, PartialEq)] 7 | pub struct RequestPathFormatter<'a> { 8 | method: &'a Method, 9 | 10 | /// This is the path that the user requested. 11 | user_requested_path: &'a str, 12 | query_params: Option<&'a QueryParamsStore>, 13 | } 14 | 15 | impl<'a> RequestPathFormatter<'a> { 16 | pub fn new( 17 | method: &'a Method, 18 | user_requested_path: &'a str, 19 | query_params: Option<&'a QueryParamsStore>, 20 | ) -> Self { 21 | Self { 22 | method, 23 | user_requested_path, 24 | query_params, 25 | } 26 | } 27 | } 28 | 29 | impl fmt::Display for RequestPathFormatter<'_> { 30 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 31 | let method = &self.method; 32 | let user_requested_path = &self.user_requested_path; 33 | 34 | match self.query_params { 35 | None => { 36 | write!(f, "{method} {user_requested_path}") 37 | } 38 | Some(query_params) => { 39 | if query_params.is_empty() { 40 | write!(f, "{method} {user_requested_path}") 41 | } else { 42 | write!(f, "{method} {user_requested_path}?{query_params}") 43 | } 44 | } 45 | } 46 | } 47 | } 48 | 49 | #[cfg(test)] 50 | mod test_fmt { 51 | use super::*; 52 | 53 | #[test] 54 | fn it_should_format_with_path_given() { 55 | let query_params = QueryParamsStore::new(); 56 | let debug = RequestPathFormatter::new(&Method::GET, &"/donkeys", Some(&query_params)); 57 | let output = format!("{}", debug); 58 | 59 | assert_eq!(output, "GET /donkeys"); 60 | } 61 | 62 | #[test] 63 | fn it_should_format_with_path_given_and_no_query_params() { 64 | let debug = 
RequestPathFormatter::new(&Method::GET, &"/donkeys", None); 65 | let output = format!("{}", debug); 66 | 67 | assert_eq!(output, "GET /donkeys"); 68 | } 69 | 70 | #[test] 71 | fn it_should_format_with_path_given_and_query_params() { 72 | let mut query_params = QueryParamsStore::new(); 73 | query_params.add_raw("value=123".to_string()); 74 | query_params.add_raw("another-value".to_string()); 75 | 76 | let debug = RequestPathFormatter::new(&Method::GET, &"/donkeys", Some(&query_params)); 77 | let output = format!("{}", debug); 78 | 79 | assert_eq!(output, "GET /donkeys?value=123&another-value"); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/internals/starting_tcp_setup.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Context; 2 | use anyhow::Result; 3 | use reserve_port::ReservedPort; 4 | use std::net::IpAddr; 5 | use std::net::Ipv4Addr; 6 | use std::net::SocketAddr; 7 | use std::net::TcpListener as StdTcpListener; 8 | use tokio::net::TcpListener as TokioTcpListener; 9 | 10 | pub const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); 11 | 12 | #[derive(Debug)] 13 | pub struct StartingTcpSetup { 14 | pub maybe_reserved_port: Option, 15 | pub socket_addr: SocketAddr, 16 | pub tcp_listener: TokioTcpListener, 17 | } 18 | 19 | impl StartingTcpSetup { 20 | pub fn new(maybe_ip: Option, maybe_port: Option) -> Result { 21 | let ip = maybe_ip.unwrap_or(DEFAULT_IP_ADDRESS); 22 | 23 | maybe_port 24 | .map(|port| Self::new_with_port(ip, port)) 25 | .unwrap_or_else(|| Self::new_without_port(ip)) 26 | } 27 | 28 | fn new_with_port(ip: IpAddr, port: u16) -> Result { 29 | ReservedPort::reserve_port(port)?; 30 | let socket_addr = SocketAddr::new(ip, port); 31 | let std_tcp_listener = StdTcpListener::bind(socket_addr) 32 | .context("Failed to create TCPListener for TestServer")?; 33 | std_tcp_listener.set_nonblocking(true)?; 34 | let tokio_tcp_listener = 
TokioTcpListener::from_std(std_tcp_listener)?; 35 | 36 | Ok(Self { 37 | maybe_reserved_port: None, 38 | socket_addr, 39 | tcp_listener: tokio_tcp_listener, 40 | }) 41 | } 42 | 43 | fn new_without_port(ip: IpAddr) -> Result { 44 | let (reserved_port, std_tcp_listener) = ReservedPort::random_with_tcp(ip)?; 45 | let socket_addr = SocketAddr::new(ip, reserved_port.port()); 46 | std_tcp_listener.set_nonblocking(true)?; 47 | let tokio_tcp_listener = TokioTcpListener::from_std(std_tcp_listener)?; 48 | 49 | Ok(Self { 50 | maybe_reserved_port: Some(reserved_port), 51 | socket_addr, 52 | tcp_listener: tokio_tcp_listener, 53 | }) 54 | } 55 | } 56 | 57 | #[cfg(test)] 58 | mod test_new { 59 | use super::*; 60 | use regex::Regex; 61 | use std::net::Ipv4Addr; 62 | 63 | #[tokio::test] 64 | async fn it_should_create_default_ip_with_random_port_when_none() { 65 | let ip = None; 66 | let port = None; 67 | 68 | let setup = StartingTcpSetup::new(ip, port).unwrap(); 69 | let addr = format!("{}", setup.socket_addr); 70 | 71 | let regex = Regex::new("^127\\.0\\.0\\.1:[0-9]+$").unwrap(); 72 | let is_match = regex.is_match(&addr); 73 | assert!(is_match); 74 | } 75 | 76 | #[tokio::test] 77 | async fn it_should_create_ip_with_random_port_when_ip_given() { 78 | let ip = Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))); 79 | let port = None; 80 | 81 | let setup = StartingTcpSetup::new(ip, port).unwrap(); 82 | let addr = format!("{}", setup.socket_addr); 83 | 84 | let regex = Regex::new("^127\\.0\\.0\\.1:[0-9]+$").unwrap(); 85 | let is_match = regex.is_match(&addr); 86 | assert!(is_match); 87 | } 88 | 89 | #[tokio::test] 90 | async fn it_should_create_default_ip_with_port_when_port_given() { 91 | let ip = None; 92 | let port = Some(8123); 93 | 94 | let setup = StartingTcpSetup::new(ip, port).unwrap(); 95 | let addr = format!("{}", setup.socket_addr); 96 | 97 | assert_eq!(addr, "127.0.0.1:8123"); 98 | } 99 | 100 | #[tokio::test] 101 | async fn it_should_create_ip_port_given_when_both_given() { 102 | 
let ip = Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))); 103 | let port = Some(8124); 104 | 105 | let setup = StartingTcpSetup::new(ip, port).unwrap(); 106 | let addr = format!("{}", setup.socket_addr); 107 | 108 | assert_eq!(addr, "127.0.0.1:8124"); 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /src/internals/status_code_formatter.rs: -------------------------------------------------------------------------------- 1 | use http::StatusCode; 2 | use std::fmt; 3 | 4 | #[derive(Debug, Copy, Clone, PartialEq)] 5 | pub struct StatusCodeFormatter(pub StatusCode); 6 | 7 | impl fmt::Display for StatusCodeFormatter { 8 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 9 | let code = self.0.as_u16(); 10 | let reason = self.0.canonical_reason().unwrap_or("unknown status code"); 11 | 12 | write!(f, "{code} ({reason})") 13 | } 14 | } 15 | 16 | #[cfg(test)] 17 | mod test_fmt { 18 | use super::*; 19 | 20 | #[test] 21 | fn it_should_format_with_reason_where_available() { 22 | let status_code = StatusCode::UNAUTHORIZED; 23 | let debug = StatusCodeFormatter(status_code); 24 | let output = format!("{}", debug); 25 | 26 | assert_eq!(output, "401 (Unauthorized)"); 27 | } 28 | 29 | #[test] 30 | fn it_should_provide_only_number_where_reason_is_unavailable() { 31 | let status_code = StatusCode::from_u16(218).unwrap(); // Unofficial Apache status code. 
32 | let debug = StatusCodeFormatter(status_code); 33 | let output = format!("{}", debug); 34 | 35 | assert_eq!(output, "218 (unknown status code)"); 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/internals/transport_layer/http_transport_layer.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use axum::body::Body; 3 | use http::Request; 4 | use http::Response; 5 | use hyper_util::client::legacy::Client; 6 | use reserve_port::ReservedPort; 7 | use std::future::Future; 8 | use std::pin::Pin; 9 | use url::Url; 10 | 11 | use crate::transport_layer::TransportLayer; 12 | use crate::transport_layer::TransportLayerType; 13 | use crate::util::ServeHandle; 14 | 15 | #[derive(Debug)] 16 | pub struct HttpTransportLayer { 17 | #[allow(dead_code)] 18 | serve_handle: ServeHandle, 19 | 20 | #[allow(dead_code)] 21 | maybe_reserved_port: Option, 22 | 23 | url: Url, 24 | } 25 | 26 | impl HttpTransportLayer { 27 | pub(crate) fn new( 28 | serve_handle: ServeHandle, 29 | maybe_reserved_port: Option, 30 | url: Url, 31 | ) -> Self { 32 | Self { 33 | serve_handle, 34 | maybe_reserved_port, 35 | url, 36 | } 37 | } 38 | } 39 | 40 | impl TransportLayer for HttpTransportLayer { 41 | fn send<'a>( 42 | &'a self, 43 | request: Request, 44 | ) -> Pin>>>> { 45 | Box::pin(async { 46 | let client = Client::builder(hyper_util::rt::TokioExecutor::new()).build_http(); 47 | let hyper_response = client.request(request).await?; 48 | 49 | let (parts, response_body) = hyper_response.into_parts(); 50 | let returned_response: Response = 51 | Response::from_parts(parts, Body::new(response_body)); 52 | 53 | Ok(returned_response) 54 | }) 55 | } 56 | 57 | fn url(&self) -> Option<&Url> { 58 | Some(&self.url) 59 | } 60 | 61 | fn transport_layer_type(&self) -> TransportLayerType { 62 | TransportLayerType::Http 63 | } 64 | 65 | fn is_running(&self) -> bool { 66 | !self.serve_handle.is_finished() 67 
| } 68 | } 69 | -------------------------------------------------------------------------------- /src/internals/transport_layer/mock_transport_layer.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Error as AnyhowError; 2 | use anyhow::Result; 3 | use axum::body::Body; 4 | use axum::response::Response as AxumResponse; 5 | use bytes::Bytes; 6 | use http::Request; 7 | use http::Response; 8 | use std::fmt::Debug; 9 | use std::future::Future; 10 | use std::pin::Pin; 11 | use tower::util::ServiceExt; 12 | use tower::Service; 13 | 14 | use crate::transport_layer::TransportLayer; 15 | use crate::transport_layer::TransportLayerType; 16 | 17 | pub struct MockTransportLayer { 18 | service: S, 19 | } 20 | 21 | impl MockTransportLayer 22 | where 23 | S: Service, Response = RouterService> + Clone + Send + Sync, 24 | AnyhowError: From, 25 | S::Future: Send, 26 | RouterService: Service, Response = AxumResponse>, 27 | { 28 | pub(crate) fn new(service: S) -> Self { 29 | Self { service } 30 | } 31 | } 32 | 33 | impl TransportLayer for MockTransportLayer 34 | where 35 | S: Service, Response = RouterService> + Clone + Send + Sync + 'static, 36 | AnyhowError: From, 37 | S::Future: Send + Sync, 38 | RouterService: Service, Response = AxumResponse>, 39 | AnyhowError: From, 40 | { 41 | fn send<'a>( 42 | &'a self, 43 | request: Request, 44 | ) -> Pin>>>> { 45 | Box::pin(async { 46 | let body: Body = Bytes::new().into(); 47 | let empty_request = Request::builder() 48 | .body(body) 49 | .expect("should build empty request"); 50 | 51 | let service = self.service.clone(); 52 | let router = service.oneshot(empty_request).await?; 53 | 54 | let response = router.oneshot(request).await?; 55 | Ok(response) 56 | }) 57 | } 58 | 59 | fn transport_layer_type(&self) -> TransportLayerType { 60 | TransportLayerType::Mock 61 | } 62 | 63 | /// This will always return true. 
64 | #[inline(always)] 65 | fn is_running(&self) -> bool { 66 | true 67 | } 68 | } 69 | 70 | impl Debug for MockTransportLayer { 71 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 72 | write!(f, "MockTransportLayer {{ service: {{unknown}} }}") 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/internals/transport_layer/mod.rs: -------------------------------------------------------------------------------- 1 | mod http_transport_layer; 2 | pub use self::http_transport_layer::*; 3 | 4 | mod mock_transport_layer; 5 | pub use self::mock_transport_layer::*; 6 | -------------------------------------------------------------------------------- /src/internals/try_into_range_bounds.rs: -------------------------------------------------------------------------------- 1 | use std::convert::Infallible; 2 | use std::fmt::Debug; 3 | use std::ops::Range; 4 | use std::ops::RangeBounds; 5 | use std::ops::RangeFrom; 6 | use std::ops::RangeFull; 7 | use std::ops::RangeInclusive; 8 | use std::ops::RangeTo; 9 | use std::ops::RangeToInclusive; 10 | 11 | pub trait TryIntoRangeBounds { 12 | type TargetRange: RangeBounds; 13 | type Error: Debug; 14 | 15 | fn try_into_range_bounds(self) -> Result; 16 | } 17 | 18 | impl TryIntoRangeBounds for Range 19 | where 20 | A: TryInto, 21 | A::Error: Debug, 22 | { 23 | type TargetRange = Range; 24 | type Error = >::Error; 25 | 26 | fn try_into_range_bounds(self) -> Result { 27 | Ok(self.start.try_into()?..self.end.try_into()?) 28 | } 29 | } 30 | 31 | impl TryIntoRangeBounds for RangeFrom 32 | where 33 | A: TryInto, 34 | A::Error: Debug, 35 | { 36 | type TargetRange = RangeFrom; 37 | type Error = >::Error; 38 | 39 | fn try_into_range_bounds(self) -> Result { 40 | Ok(self.start.try_into()?..) 
41 | } 42 | } 43 | 44 | impl TryIntoRangeBounds for RangeTo 45 | where 46 | A: TryInto, 47 | A::Error: Debug, 48 | { 49 | type TargetRange = RangeTo; 50 | type Error = >::Error; 51 | 52 | fn try_into_range_bounds(self) -> Result { 53 | Ok(..self.end.try_into()?) 54 | } 55 | } 56 | 57 | impl TryIntoRangeBounds for RangeInclusive 58 | where 59 | A: TryInto, 60 | A::Error: Debug, 61 | { 62 | type TargetRange = RangeInclusive; 63 | type Error = >::Error; 64 | 65 | fn try_into_range_bounds(self) -> Result { 66 | let (start, end) = self.into_inner(); 67 | Ok(start.try_into()?..=end.try_into()?) 68 | } 69 | } 70 | 71 | impl TryIntoRangeBounds for RangeToInclusive 72 | where 73 | A: TryInto, 74 | A::Error: Debug, 75 | { 76 | type TargetRange = RangeToInclusive; 77 | type Error = >::Error; 78 | 79 | fn try_into_range_bounds(self) -> Result { 80 | Ok(..=self.end.try_into()?) 81 | } 82 | } 83 | 84 | impl TryIntoRangeBounds for RangeFull { 85 | type TargetRange = RangeFull; 86 | type Error = Infallible; 87 | 88 | fn try_into_range_bounds(self) -> Result { 89 | Ok(self) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /src/internals/websockets/mod.rs: -------------------------------------------------------------------------------- 1 | mod test_response_websocket; 2 | pub use self::test_response_websocket::*; 3 | 4 | mod ws_key_generator; 5 | pub use self::ws_key_generator::*; 6 | -------------------------------------------------------------------------------- /src/internals/websockets/test_response_websocket.rs: -------------------------------------------------------------------------------- 1 | use hyper::upgrade::OnUpgrade; 2 | 3 | use crate::transport_layer::TransportLayerType; 4 | 5 | #[derive(Debug, Clone)] 6 | pub struct TestResponseWebSocket { 7 | pub maybe_on_upgrade: Option, 8 | pub transport_type: TransportLayerType, 9 | } 10 | -------------------------------------------------------------------------------- 
/src/internals/websockets/ws_key_generator.rs: -------------------------------------------------------------------------------- 1 | use base64::engine::general_purpose::STANDARD; 2 | use base64::Engine; 3 | use uuid::Uuid; 4 | 5 | /// Generates a random key for use, that is base 64 encoded for use over HTTP. 6 | pub fn generate_ws_key() -> String { 7 | STANDARD.encode(Uuid::new_v4().as_bytes()) 8 | } 9 | -------------------------------------------------------------------------------- /src/internals/with_this_mut.rs: -------------------------------------------------------------------------------- 1 | use anyhow::anyhow; 2 | use anyhow::Result; 3 | use std::sync::Arc; 4 | use std::sync::Mutex; 5 | 6 | pub fn with_this_mut(this: &Arc>, name: &str, some_action: F) -> Result 7 | where 8 | F: FnOnce(&mut T) -> R, 9 | { 10 | let mut this_locked = this.lock().map_err(|err| { 11 | anyhow!( 12 | "Failed to lock InternalTestServer for `{}`, {:?}", 13 | name, 14 | err, 15 | ) 16 | })?; 17 | 18 | let result = some_action(&mut this_locked); 19 | 20 | Ok(result) 21 | } 22 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! Axum Test is a library for writing tests for web servers written using Axum: 3 | //! 4 | //! * You create a [`TestServer`] within a test, 5 | //! * use that to build [`TestRequest`] against your application, 6 | //! * receive back a [`TestResponse`], 7 | //! * then assert the response is how you expect. 8 | //! 9 | //! It includes built in support for serializing and deserializing request and response bodies using Serde, 10 | //! support for cookies and headers, and other common bits you would expect. 11 | //! 12 | //! `TestServer` will pass http requests directly to the handler, 13 | //! or can be run on a random IP / Port address. 14 | //! 15 | //! ## Getting Started 16 | //! 17 | //! 
Create a [`TestServer`] running your Axum [`Router`](::axum::Router): 18 | //! 19 | //! ```rust 20 | //! # async fn test() -> Result<(), Box> { 21 | //! # 22 | //! use axum::Router; 23 | //! use axum::extract::Json; 24 | //! use axum::routing::put; 25 | //! use axum_test::TestServer; 26 | //! use serde_json::json; 27 | //! use serde_json::Value; 28 | //! 29 | //! async fn route_put_user(Json(user): Json) -> () { 30 | //! // todo 31 | //! } 32 | //! 33 | //! let my_app = Router::new() 34 | //! .route("/users", put(route_put_user)); 35 | //! 36 | //! let server = TestServer::new(my_app)?; 37 | //! # 38 | //! # Ok(()) 39 | //! # } 40 | //! ``` 41 | //! 42 | //! Then make requests against it: 43 | //! 44 | //! ```rust 45 | //! # async fn test() -> Result<(), Box> { 46 | //! # 47 | //! # use axum::Router; 48 | //! # use axum::extract::Json; 49 | //! # use axum::routing::put; 50 | //! # use axum_test::TestServer; 51 | //! # use serde_json::json; 52 | //! # use serde_json::Value; 53 | //! # 54 | //! # async fn put_user(Json(user): Json) -> () {} 55 | //! # 56 | //! # let my_app = Router::new() 57 | //! # .route("/users", put(put_user)); 58 | //! # 59 | //! # let server = TestServer::new(my_app)?; 60 | //! # 61 | //! let response = server.put("/users") 62 | //! .json(&json!({ 63 | //! "username": "Terrance Pencilworth", 64 | //! })) 65 | //! .await; 66 | //! # 67 | //! # Ok(()) 68 | //! # } 69 | //! ``` 70 | //! 
71 | 72 | #![allow(clippy::module_inception)] 73 | #![allow(clippy::derivable_impls)] 74 | #![allow(clippy::manual_range_contains)] 75 | #![forbid(unsafe_code)] 76 | #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] 77 | 78 | pub(crate) mod internals; 79 | 80 | pub mod multipart; 81 | 82 | pub mod transport_layer; 83 | pub mod util; 84 | 85 | mod test_request; 86 | pub use self::test_request::*; 87 | 88 | mod test_response; 89 | pub use self::test_response::*; 90 | 91 | mod test_server_builder; 92 | pub use self::test_server_builder::*; 93 | 94 | mod test_server_config; 95 | pub use self::test_server_config::*; 96 | 97 | mod test_server; 98 | pub use self::test_server::*; 99 | 100 | #[cfg(feature = "ws")] 101 | mod test_web_socket; 102 | #[cfg(feature = "ws")] 103 | pub use self::test_web_socket::*; 104 | #[cfg(feature = "ws")] 105 | pub use tokio_tungstenite::tungstenite::Message as WsMessage; 106 | 107 | mod transport; 108 | pub use self::transport::*; 109 | 110 | pub mod expect_json; 111 | 112 | pub use http; 113 | 114 | #[cfg(test)] 115 | mod integrated_test_cookie_saving { 116 | use super::*; 117 | 118 | use axum::extract::Request; 119 | use axum::routing::get; 120 | use axum::routing::post; 121 | use axum::routing::put; 122 | use axum::Router; 123 | use axum_extra::extract::cookie::Cookie as AxumCookie; 124 | use axum_extra::extract::cookie::CookieJar; 125 | use cookie::time::OffsetDateTime; 126 | use cookie::Cookie; 127 | use http_body_util::BodyExt; 128 | use std::time::Duration; 129 | 130 | const TEST_COOKIE_NAME: &'static str = &"test-cookie"; 131 | 132 | async fn get_cookie(cookies: CookieJar) -> (CookieJar, String) { 133 | let cookie = cookies.get(&TEST_COOKIE_NAME); 134 | let cookie_value = cookie 135 | .map(|c| c.value().to_string()) 136 | .unwrap_or_else(|| "cookie-not-found".to_string()); 137 | 138 | (cookies, cookie_value) 139 | } 140 | 141 | async fn put_cookie(mut cookies: CookieJar, request: Request) -> (CookieJar, &'static str) { 142 | let 
body_bytes = request 143 | .into_body() 144 | .collect() 145 | .await 146 | .expect("Should extract the body") 147 | .to_bytes(); 148 | let body_text: String = String::from_utf8_lossy(&body_bytes).to_string(); 149 | let cookie = AxumCookie::new(TEST_COOKIE_NAME, body_text); 150 | cookies = cookies.add(cookie); 151 | 152 | (cookies, &"done") 153 | } 154 | 155 | async fn post_expire_cookie(mut cookies: CookieJar) -> (CookieJar, &'static str) { 156 | let mut cookie = AxumCookie::new(TEST_COOKIE_NAME, "expired".to_string()); 157 | let expired_time = OffsetDateTime::now_utc() - Duration::from_secs(1); 158 | cookie.set_expires(expired_time); 159 | cookies = cookies.add(cookie); 160 | 161 | (cookies, &"done") 162 | } 163 | 164 | fn new_test_router() -> Router { 165 | Router::new() 166 | .route("/cookie", put(put_cookie)) 167 | .route("/cookie", get(get_cookie)) 168 | .route("/expire", post(post_expire_cookie)) 169 | } 170 | 171 | #[tokio::test] 172 | async fn it_should_not_pass_cookies_created_back_up_to_server_by_default() { 173 | // Run the server. 174 | let server = TestServer::new(new_test_router()).expect("Should create test server"); 175 | 176 | // Create a cookie. 177 | server.put(&"/cookie").text(&"new-cookie").await; 178 | 179 | // Check it comes back. 180 | let response_text = server.get(&"/cookie").await.text(); 181 | 182 | assert_eq!(response_text, "cookie-not-found"); 183 | } 184 | 185 | #[tokio::test] 186 | async fn it_should_not_pass_cookies_created_back_up_to_server_when_turned_off() { 187 | // Run the server. 188 | let server = TestServer::builder() 189 | .do_not_save_cookies() 190 | .build(new_test_router()) 191 | .expect("Should create test server"); 192 | 193 | // Create a cookie. 194 | server.put(&"/cookie").text(&"new-cookie").await; 195 | 196 | // Check it comes back. 
197 | let response_text = server.get(&"/cookie").await.text(); 198 | 199 | assert_eq!(response_text, "cookie-not-found"); 200 | } 201 | 202 | #[tokio::test] 203 | async fn it_should_pass_cookies_created_back_up_to_server_automatically() { 204 | // Run the server. 205 | let server = TestServer::builder() 206 | .save_cookies() 207 | .build(new_test_router()) 208 | .expect("Should create test server"); 209 | 210 | // Create a cookie. 211 | server.put(&"/cookie").text(&"cookie-found!").await; 212 | 213 | // Check it comes back. 214 | let response_text = server.get(&"/cookie").await.text(); 215 | 216 | assert_eq!(response_text, "cookie-found!"); 217 | } 218 | 219 | #[tokio::test] 220 | async fn it_should_pass_cookies_created_back_up_to_server_when_turned_on_for_request() { 221 | // Run the server. 222 | let server = TestServer::builder() 223 | .do_not_save_cookies() // it's off by default! 224 | .build(new_test_router()) 225 | .expect("Should create test server"); 226 | 227 | // Create a cookie. 228 | server 229 | .put(&"/cookie") 230 | .text(&"cookie-found!") 231 | .save_cookies() 232 | .await; 233 | 234 | // Check it comes back. 235 | let response_text = server.get(&"/cookie").await.text(); 236 | 237 | assert_eq!(response_text, "cookie-found!"); 238 | } 239 | 240 | #[tokio::test] 241 | async fn it_should_wipe_cookies_cleared_by_request() { 242 | // Run the server. 243 | let server = TestServer::builder() 244 | .do_not_save_cookies() // it's off by default! 245 | .build(new_test_router()) 246 | .expect("Should create test server"); 247 | 248 | // Create a cookie. 249 | server 250 | .put(&"/cookie") 251 | .text(&"cookie-found!") 252 | .save_cookies() 253 | .await; 254 | 255 | // Check it comes back. 256 | let response_text = server.get(&"/cookie").clear_cookies().await.text(); 257 | 258 | assert_eq!(response_text, "cookie-not-found"); 259 | } 260 | 261 | #[tokio::test] 262 | async fn it_should_wipe_cookies_cleared_by_test_server() { 263 | // Run the server. 
264 | let mut server = TestServer::builder() 265 | .do_not_save_cookies() // it's off by default! 266 | .build(new_test_router()) 267 | .expect("Should create test server"); 268 | 269 | // Create a cookie. 270 | server 271 | .put(&"/cookie") 272 | .text(&"cookie-found!") 273 | .save_cookies() 274 | .await; 275 | 276 | server.clear_cookies(); 277 | 278 | // Check it comes back. 279 | let response_text = server.get(&"/cookie").await.text(); 280 | 281 | assert_eq!(response_text, "cookie-not-found"); 282 | } 283 | 284 | #[tokio::test] 285 | async fn it_should_send_cookies_added_to_request() { 286 | // Run the server. 287 | let server = TestServer::builder() 288 | .do_not_save_cookies() // it's off by default! 289 | .build(new_test_router()) 290 | .expect("Should create test server"); 291 | 292 | // Check it comes back. 293 | let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"); 294 | 295 | let response_text = server.get(&"/cookie").add_cookie(cookie).await.text(); 296 | 297 | assert_eq!(response_text, "my-custom-cookie"); 298 | } 299 | 300 | #[tokio::test] 301 | async fn it_should_send_cookies_added_to_test_server() { 302 | // Run the server. 303 | let mut server = TestServer::builder() 304 | .do_not_save_cookies() // it's off by default! 305 | .build(new_test_router()) 306 | .expect("Should create test server"); 307 | 308 | // Check it comes back. 309 | let cookie = Cookie::new(TEST_COOKIE_NAME, "my-custom-cookie"); 310 | server.add_cookie(cookie); 311 | 312 | let response_text = server.get(&"/cookie").await.text(); 313 | 314 | assert_eq!(response_text, "my-custom-cookie"); 315 | } 316 | 317 | #[tokio::test] 318 | async fn it_should_remove_expired_cookies_from_later_requests() { 319 | // Run the server. 320 | let mut server = TestServer::new(new_test_router()).expect("Should create test server"); 321 | server.save_cookies(); 322 | 323 | // Create a cookie. 324 | server.put(&"/cookie").text(&"cookie-found!").await; 325 | 326 | // Check it comes back. 
327 | let response_text = server.get(&"/cookie").await.text(); 328 | assert_eq!(response_text, "cookie-found!"); 329 | 330 | server.post(&"/expire").await; 331 | 332 | // Then expire the cookie. 333 | let found_cookie = server.post(&"/expire").await.maybe_cookie(TEST_COOKIE_NAME); 334 | assert!(found_cookie.is_some()); 335 | 336 | // It's no longer found 337 | let response_text = server.get(&"/cookie").await.text(); 338 | assert_eq!(response_text, "cookie-not-found"); 339 | } 340 | } 341 | 342 | #[cfg(feature = "typed-routing")] 343 | #[cfg(test)] 344 | mod integrated_test_typed_routing_and_query { 345 | use super::*; 346 | 347 | use axum::extract::Query; 348 | use axum::Router; 349 | use axum_extra::routing::RouterExt; 350 | use axum_extra::routing::TypedPath; 351 | use serde::Deserialize; 352 | use serde::Serialize; 353 | 354 | #[derive(TypedPath, Deserialize)] 355 | #[typed_path("/path-query/{id}")] 356 | struct TestingPathQuery { 357 | id: u32, 358 | } 359 | 360 | #[derive(Serialize, Deserialize)] 361 | struct QueryParams { 362 | param: String, 363 | other: Option, 364 | } 365 | 366 | async fn route_get_with_param( 367 | TestingPathQuery { id }: TestingPathQuery, 368 | Query(params): Query, 369 | ) -> String { 370 | let query = params.param; 371 | if let Some(other) = params.other { 372 | format!("get {id}, {query}&{other}") 373 | } else { 374 | format!("get {id}, {query}") 375 | } 376 | } 377 | 378 | fn new_app() -> Router { 379 | Router::new().typed_get(route_get_with_param) 380 | } 381 | 382 | #[tokio::test] 383 | async fn it_should_send_typed_get_with_query_params() { 384 | let server = TestServer::new(new_app()).unwrap(); 385 | let path = TestingPathQuery { id: 123 }.with_query_params(QueryParams { 386 | param: "with-typed-query".to_string(), 387 | other: None, 388 | }); 389 | 390 | server 391 | .typed_get(&path) 392 | .expect_success() 393 | .await 394 | .assert_text("get 123, with-typed-query"); 395 | } 396 | 397 | #[tokio::test] 398 | async fn 
it_should_send_typed_get_with_added_query_param() { 399 | let server = TestServer::new(new_app()).unwrap(); 400 | let path = TestingPathQuery { id: 123 }; 401 | 402 | server 403 | .typed_get(&path) 404 | .add_query_param("param", "with-added-query") 405 | .expect_success() 406 | .await 407 | .assert_text("get 123, with-added-query"); 408 | } 409 | 410 | #[tokio::test] 411 | async fn it_should_send_both_typed_and_added_query() { 412 | let server = TestServer::new(new_app()).unwrap(); 413 | let path = TestingPathQuery { id: 123 }.with_query_params(QueryParams { 414 | param: "with-typed-query".to_string(), 415 | other: None, 416 | }); 417 | 418 | server 419 | .typed_get(&path) 420 | .add_query_param("other", "with-added-query") 421 | .expect_success() 422 | .await 423 | .assert_text("get 123, with-typed-query&with-added-query"); 424 | } 425 | 426 | #[tokio::test] 427 | async fn it_should_send_replaced_query_when_cleared() { 428 | let server = TestServer::new(new_app()).unwrap(); 429 | let path = TestingPathQuery { id: 123 }.with_query_params(QueryParams { 430 | param: "with-typed-query".to_string(), 431 | other: Some("with-typed-other".to_string()), 432 | }); 433 | 434 | server 435 | .typed_get(&path) 436 | .clear_query_params() 437 | .add_query_param("param", "with-added-query") 438 | .expect_success() 439 | .await 440 | .assert_text("get 123, with-added-query"); 441 | } 442 | } 443 | -------------------------------------------------------------------------------- /src/multipart/mod.rs: -------------------------------------------------------------------------------- 1 | //! 2 | //! This supplies the building blocks for sending multipart forms using 3 | //! [`TestRequest::multipart()`](crate::TestRequest::multipart()). 4 | //! 5 | //! The request body can be built using [`MultipartForm`] and [`Part`]. 6 | //! 7 | //! # Simple example 8 | //! 9 | //! ```rust 10 | //! # async fn test() -> Result<(), Box> { 11 | //! # 12 | //! use axum::Router; 13 | //! 
use axum_test::TestServer; 14 | //! use axum_test::multipart::MultipartForm; 15 | //! 16 | //! let app = Router::new(); 17 | //! let server = TestServer::new(app)?; 18 | //! 19 | //! let multipart_form = MultipartForm::new() 20 | //! .add_text("name", "Joe") 21 | //! .add_text("animals", "foxes"); 22 | //! 23 | //! let response = server.post(&"/my-form") 24 | //! .multipart(multipart_form) 25 | //! .await; 26 | //! # 27 | //! # Ok(()) } 28 | //! ``` 29 | //! 30 | //! # Sending byte parts 31 | //! 32 | //! ```rust 33 | //! # async fn test() -> Result<(), Box> { 34 | //! # 35 | //! use axum::Router; 36 | //! use axum_test::TestServer; 37 | //! use axum_test::multipart::MultipartForm; 38 | //! use axum_test::multipart::Part; 39 | //! 40 | //! let app = Router::new(); 41 | //! let server = TestServer::new(app)?; 42 | //! 43 | //! let image_bytes = include_bytes!("../../README.md"); 44 | //! let image_part = Part::bytes(image_bytes.as_slice()) 45 | //! .file_name(&"README.md") 46 | //! .mime_type(&"text/markdown"); 47 | //! 48 | //! let multipart_form = MultipartForm::new() 49 | //! .add_part("file", image_part); 50 | //! 51 | //! let response = server.post(&"/my-form") 52 | //! .multipart(multipart_form) 53 | //! .await; 54 | //! # 55 | //! # Ok(()) } 56 | //! ``` 57 | //! 
58 | 59 | mod multipart_form; 60 | pub use self::multipart_form::*; 61 | 62 | mod part; 63 | pub use self::part::*; 64 | -------------------------------------------------------------------------------- /src/multipart/multipart_form.rs: -------------------------------------------------------------------------------- 1 | use crate::multipart::Part; 2 | use axum::body::Body as AxumBody; 3 | use rust_multipart_rfc7578_2::client::multipart::Body as CommonMultipartBody; 4 | use rust_multipart_rfc7578_2::client::multipart::Form; 5 | use std::fmt::Display; 6 | use std::io::Cursor; 7 | /// A multipart form, built from text and [`Part`]s, for sending with [`TestRequest::multipart()`](crate::TestRequest::multipart()). 8 | #[derive(Debug)] 9 | pub struct MultipartForm { 10 | inner: Form<'static>, 11 | } 12 | 13 | impl MultipartForm { 14 | pub fn new() -> Self { 15 | Default::default() 16 | } 17 | 18 | /// Creates a text part, and adds it to be sent. 19 | pub fn add_text(mut self, name: N, text: T) -> Self 20 | where 21 | N: Display, 22 | T: ToString, 23 | { 24 | self.inner.add_text(name, text.to_string()); 25 | self 26 | } 27 | 28 | /// Adds a new section to this multipart form to be sent. 29 | /// 30 | /// See [`Part`](crate::multipart::Part). 31 | pub fn add_part(mut self, name: N, part: Part) -> Self 32 | where 33 | N: Display, 34 | { 35 | let reader = Cursor::new(part.bytes); 36 | self.inner.add_reader_2( 37 | name, 38 | reader, 39 | part.file_name, 40 | Some(part.mime_type), 41 | part.headers, 42 | ); 43 | 44 | self 45 | } 46 | 47 | /// Returns the content type this form will use when it is sent.
48 | pub fn content_type(&self) -> String { 49 | self.inner.content_type() 50 | } 51 | } 52 | 53 | impl Default for MultipartForm { 54 | fn default() -> Self { 55 | Self { 56 | inner: Default::default(), 57 | } 58 | } 59 | } 60 | // Converts the form into a streamed axum request body. 61 | impl From for AxumBody { 62 | fn from(multipart: MultipartForm) -> Self { 63 | let inner_body: CommonMultipartBody = multipart.inner.into(); 64 | AxumBody::from_stream(inner_body) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/multipart/part.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Context; 2 | use bytes::Bytes; 3 | use http::HeaderName; 4 | use http::HeaderValue; 5 | use mime::Mime; 6 | use std::fmt::Debug; 7 | use std::fmt::Display; 8 | 9 | /// 10 | /// For creating a section of a MultipartForm. 11 | /// 12 | /// Use [`Part::text()`](crate::multipart::Part::text()) and [`Part::bytes()`](crate::multipart::Part::bytes()) for creating new instances. 13 | /// Then attach them to a `MultipartForm` using [`MultipartForm::add_part()`](crate::multipart::MultipartForm::add_part()). 14 | /// 15 | #[derive(Debug, Clone)] 16 | pub struct Part { 17 | pub(crate) bytes: Bytes, 18 | pub(crate) file_name: Option, 19 | pub(crate) mime_type: Mime, 20 | pub(crate) headers: Vec<(HeaderName, HeaderValue)>, 21 | } 22 | 23 | impl Part { 24 | /// Creates a new part of a multipart form, that will send text. 25 | /// 26 | /// The default mime type for this part will be `text/plain`. 27 | pub fn text(text: T) -> Self 28 | where 29 | T: Display, 30 | { 31 | let bytes = text.to_string().into_bytes().into(); 32 | 33 | Self::new(bytes, mime::TEXT_PLAIN) 34 | } 35 | 36 | /// Creates a new part of a multipart form, that will upload bytes.
37 | /// 38 | /// The default mime type for this part will be `application/octet-stream`. 39 | pub fn bytes(bytes: B) -> Self 40 | where 41 | B: Into, 42 | { 43 | Self::new(bytes.into(), mime::APPLICATION_OCTET_STREAM) 44 | } 45 | // Shared constructor: starts with no file name, and no extra headers. 46 | fn new(bytes: Bytes, mime_type: Mime) -> Self { 47 | Self { 48 | bytes, 49 | file_name: None, 50 | mime_type, 51 | headers: Default::default(), 52 | } 53 | } 54 | 55 | /// Sets the file name for this part of a multipart form. 56 | /// 57 | /// By default there is no filename. This will set one. 58 | pub fn file_name(mut self, file_name: T) -> Self 59 | where 60 | T: Display, 61 | { 62 | self.file_name = Some(file_name.to_string()); 63 | self 64 | } 65 | 66 | /// Sets the mime type for this part of a multipart form. 67 | /// 68 | /// The default mime type is `text/plain` or `application/octet-stream`, 69 | /// depending on how this instance was created. 70 | /// This function will replace that, and panics if `mime_type` cannot be parsed as a mime type. 71 | pub fn mime_type(mut self, mime_type: M) -> Self 72 | where 73 | M: AsRef, 74 | { 75 | let raw_mime_type = mime_type.as_ref(); 76 | let parsed_mime_type = raw_mime_type 77 | .parse() 78 | .with_context(|| format!("Failed to parse '{raw_mime_type}' as a Mime type")) 79 | .unwrap(); 80 | 81 | self.mime_type = parsed_mime_type; 82 | 83 | self 84 | } 85 | 86 | /// Adds a header to be sent with this Part of a multipart form.
87 | /// 88 | /// ```rust 89 | /// # async fn test() -> Result<(), Box> { 90 | /// # 91 | /// use axum::Router; 92 | /// use axum_test::TestServer; 93 | /// use axum_test::multipart::MultipartForm; 94 | /// use axum_test::multipart::Part; 95 | /// 96 | /// let app = Router::new(); 97 | /// let server = TestServer::new(app)?; 98 | /// 99 | /// let readme_bytes = include_bytes!("../../README.md"); 100 | /// let readme_part = Part::bytes(readme_bytes.as_slice()) 101 | /// .file_name(&"README.md") 102 | /// // Add a header to the Part 103 | /// .add_header("x-text-category", "readme"); 104 | /// 105 | /// let multipart_form = MultipartForm::new() 106 | /// .add_part("file", readme_part); 107 | /// 108 | /// let response = server.post(&"/my-form") 109 | /// .multipart(multipart_form) 110 | /// .await; 111 | /// # 112 | /// # Ok(()) } 113 | /// ``` 114 | /// Panics if `name` or `value` cannot be converted into a `HeaderName` or `HeaderValue`. 115 | pub fn add_header(mut self, name: N, value: V) -> Self 116 | where 117 | N: TryInto, 118 | N::Error: Debug, 119 | V: TryInto, 120 | V::Error: Debug, 121 | { 122 | let header_name: HeaderName = name 123 | .try_into() 124 | .expect("Failed to convert header name to HeaderName"); 125 | let header_value: HeaderValue = value 126 | .try_into() 127 | .expect("Failed to convert header value to HeaderValue"); 128 | 129 | self.headers.push((header_name, header_value)); 130 | self 131 | } 132 | } 133 | 134 | #[cfg(test)] 135 | mod test_text { 136 | use super::*; 137 | 138 | #[test] 139 | fn it_should_contain_text_given() { 140 | let part = Part::text("some_text"); 141 | 142 | let output = String::from_utf8_lossy(&part.bytes); 143 | assert_eq!(output, "some_text"); 144 | } 145 | 146 | #[test] 147 | fn it_should_use_mime_type_text() { 148 | let part = Part::text("some_text"); 149 | assert_eq!(part.mime_type, mime::TEXT_PLAIN); 150 | } 151 | } 152 | 153 | #[cfg(test)] 154 | mod test_bytes { 155 | use super::*; 156 | 157 | #[test] 158 | fn it_should_contain_bytes_given() { 159 | let bytes = "some_text".as_bytes(); 160 | let
part = Part::bytes(bytes); 161 | 162 | let output = String::from_utf8_lossy(&part.bytes); 163 | assert_eq!(output, "some_text"); 164 | } 165 | 166 | #[test] 167 | fn it_should_use_mime_type_octet_stream() { 168 | let bytes = "some_text".as_bytes(); 169 | let part = Part::bytes(bytes); 170 | 171 | assert_eq!(part.mime_type, mime::APPLICATION_OCTET_STREAM); 172 | } 173 | } 174 | 175 | #[cfg(test)] 176 | mod test_file_name { 177 | use super::*; 178 | 179 | #[test] 180 | fn it_should_use_file_name_given() { 181 | let mut part = Part::text("some_text"); 182 | 183 | assert_eq!(part.file_name, None); 184 | part = part.file_name("my-text.txt"); 185 | assert_eq!(part.file_name, Some("my-text.txt".to_string())); 186 | } 187 | } 188 | 189 | #[cfg(test)] 190 | mod test_mime_type { 191 | use super::*; 192 | 193 | #[test] 194 | fn it_should_use_mime_type_set() { 195 | let mut part = Part::text("some_text"); 196 | 197 | assert_eq!(part.mime_type, mime::TEXT_PLAIN); 198 | part = part.mime_type("application/json"); 199 | assert_eq!(part.mime_type, mime::APPLICATION_JSON); 200 | } 201 | 202 | #[test] 203 | #[should_panic(expected = "Failed to parse '🦊' as a Mime type")] 204 | fn it_should_error_if_invalid_mime_type() { 205 | let part = Part::text("some_text"); 206 | part.mime_type("🦊"); 207 | 208 | // `expected` ensures the panic comes from `mime_type`, not from the test itself. 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /src/test_request/test_request_config.rs: -------------------------------------------------------------------------------- 1 | use cookie::CookieJar; 2 | use http::HeaderName; 3 | use http::HeaderValue; 4 | use http::Method; 5 | use url::Url; 6 | 7 | use crate::internals::ExpectedState; 8 | use crate::internals::QueryParamsStore; 9 | /// Configuration for a single request: its method, URL, content type, cookies, query params, headers, and expected state. 10 | #[derive(Debug, Clone)] 11 | pub struct TestRequestConfig { 12 | pub is_saving_cookies: bool, 13 | pub expected_state: ExpectedState, 14 | pub content_type: Option, 15 | pub full_request_url: Url, 16 | pub method: Method, 17 | 18 | pub cookies: CookieJar, 19 | pub query_params:
QueryParamsStore, 20 | pub headers: Vec<(HeaderName, HeaderValue)>, 21 | } 22 | -------------------------------------------------------------------------------- /src/test_server/server_shared_state.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Context; 2 | use anyhow::Result; 3 | use cookie::Cookie; 4 | use cookie::CookieJar; 5 | use http::HeaderName; 6 | use http::HeaderValue; 7 | use serde::Serialize; 8 | use std::sync::Arc; 9 | use std::sync::Mutex; 10 | 11 | use crate::internals::with_this_mut; 12 | use crate::internals::QueryParamsStore; 13 | /// Mutable server-wide state (scheme, cookies, query params, and headers), shared behind an `Arc` and `Mutex`. 14 | #[derive(Debug)] 15 | pub(crate) struct ServerSharedState { 16 | scheme: Option, 17 | cookies: CookieJar, 18 | query_params: QueryParamsStore, 19 | headers: Vec<(HeaderName, HeaderValue)>, 20 | } 21 | 22 | impl ServerSharedState { 23 | pub(crate) fn new() -> Self { 24 | Self { 25 | scheme: None, 26 | cookies: CookieJar::new(), 27 | query_params: QueryParamsStore::new(), 28 | headers: Vec::new(), 29 | } 30 | } 31 | 32 | pub(crate) fn scheme(&self) -> Option<&str> { 33 | self.scheme.as_deref() 34 | } 35 | 36 | pub(crate) fn cookies(&self) -> &CookieJar { 37 | &self.cookies 38 | } 39 | 40 | pub(crate) fn query_params(&self) -> &QueryParamsStore { 41 | &self.query_params 42 | } 43 | 44 | pub(crate) fn headers(&self) -> &Vec<(HeaderName, HeaderValue)> { 45 | &self.headers 46 | } 47 | 48 | /// Parses each of the given cookie headers, and adds the resulting cookies. 49 | /// 50 | /// They will be stored over the top of the existing cookies. Returns an error if a header cannot be read as a string, or parsed as a cookie.
51 | pub(crate) fn add_cookies_by_header<'a, I>( 52 | this: &Arc>, 53 | cookie_headers: I, 54 | ) -> Result<()> 55 | where 56 | I: Iterator, 57 | { 58 | with_this_mut(this, "add_cookies_by_header", |this| { 59 | for cookie_header in cookie_headers { 60 | let cookie_header_str = cookie_header 61 | .to_str() 62 | .context("Reading cookie header for storing in the `TestServer`")?; 63 | // Propagate unreadable headers as errors (as `Cookie::parse` below does), rather than panicking. 64 | 65 | let cookie: Cookie<'static> = Cookie::parse(cookie_header_str)?.into_owned(); 66 | this.cookies.add(cookie); 67 | } 68 | 69 | Ok(()) as Result<()> 70 | })? 71 | } 72 | 73 | /// Removes all stored cookies. 74 | /// 75 | /// This replaces the stored cookie jar with a new, empty one. 76 | pub(crate) fn clear_cookies(this: &Arc>) -> Result<()> { 77 | with_this_mut(this, "clear_cookies", |this| { 78 | this.cookies = CookieJar::new(); 79 | }) 80 | } 81 | 82 | /// Adds the given cookies. 83 | /// 84 | /// They will be stored over the top of the existing cookies. 85 | pub(crate) fn add_cookies(this: &Arc>, cookies: CookieJar) -> Result<()> { 86 | with_this_mut(this, "add_cookies", |this| { 87 | for cookie in cookies.iter() { 88 | this.cookies.add(cookie.to_owned()); 89 | } 90 | }) 91 | } 92 | /// Adds a single cookie, stored over the top of the existing cookies. 93 | pub(crate) fn add_cookie(this: &Arc>, cookie: Cookie) -> Result<()> { 94 | with_this_mut(this, "add_cookie", |this| { 95 | this.cookies.add(cookie.into_owned()); 96 | }) 97 | } 98 | /// Adds the given query params, which can be any `Serialize` value. 99 | pub(crate) fn add_query_params(this: &Arc>, query_params: V) -> Result<()> 100 | where 101 | V: Serialize, 102 | { 103 | with_this_mut(this, "add_query_params", |this| { 104 | this.query_params.add(query_params) 105 | })? 106 | } 107 | /// Adds a single `key` / `value` query param. 108 | pub(crate) fn add_query_param(this: &Arc>, key: &str, value: V) -> Result<()> 109 | where 110 | V: Serialize, 111 | { 112 | with_this_mut(this, "add_query_param", |this| { 113 | this.query_params.add(&[(key, value)]) 114 | })?
115 | } 116 | /// Adds the given raw query param text. 117 | pub(crate) fn add_raw_query_param(this: &Arc>, raw_value: &str) -> Result<()> { 118 | with_this_mut(this, "add_raw_query_param", |this| { 119 | this.query_params.add_raw(raw_value.to_string()) 120 | }) 121 | } 122 | /// Removes all stored query params. 123 | pub(crate) fn clear_query_params(this: &Arc>) -> Result<()> { 124 | with_this_mut(this, "clear_query_params", |this| this.query_params.clear()) 125 | } 126 | /// Removes all stored headers. 127 | pub(crate) fn clear_headers(this: &Arc>) -> Result<()> { 128 | with_this_mut(this, "clear_headers", |this| this.headers.clear()) 129 | } 130 | /// Adds a header to be stored, after any existing headers. 131 | pub(crate) fn add_header( 132 | this: &Arc>, 133 | name: HeaderName, 134 | value: HeaderValue, 135 | ) -> Result<()> { 136 | with_this_mut(this, "add_header", |this| this.headers.push((name, value))) 137 | } 138 | /// Sets the scheme, replacing any previously set. 139 | pub(crate) fn set_scheme(this: &Arc>, scheme: String) -> Result<()> { 140 | with_this_mut(this, "set_scheme", |this| this.scheme = Some(scheme)) 141 | } 142 | /// Sets the scheme directly, for when `&mut self` is already available. 143 | pub(crate) fn set_scheme_unlocked(&mut self, scheme: String) { 144 | self.scheme = Some(scheme); 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /src/test_server_builder.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use std::net::IpAddr; 3 | 4 | use crate::transport_layer::IntoTransportLayer; 5 | use crate::TestServer; 6 | use crate::TestServerConfig; 7 | use crate::Transport; 8 | 9 | /// A builder for [`crate::TestServer`]. Inside is a [`crate::TestServerConfig`], 10 | /// configured by each method, and then turned into a server by [`crate::TestServerBuilder::build`]. 11 | /// 12 | /// The recommended way to make instances is to call [`crate::TestServer::builder`].
13 | /// 14 | /// # Creating a [`crate::TestServer`] 15 | /// 16 | /// ```rust 17 | /// # async fn test() -> Result<(), Box> { 18 | /// # 19 | /// use axum::Router; 20 | /// use axum_test::TestServerBuilder; 21 | /// 22 | /// let my_app = Router::new(); 23 | /// let server = TestServerBuilder::new() 24 | /// .save_cookies() 25 | /// .default_content_type(&"application/json") 26 | /// .build(my_app)?; 27 | /// # 28 | /// # Ok(()) 29 | /// # } 30 | /// ``` 31 | /// 32 | /// # Creating a [`crate::TestServerConfig`] 33 | /// 34 | /// ```rust 35 | /// # async fn test() -> Result<(), Box> { 36 | /// # 37 | /// use axum::Router; 38 | /// use axum_test::TestServer; 39 | /// use axum_test::TestServerBuilder; 40 | /// 41 | /// let my_app = Router::new(); 42 | /// let config = TestServerBuilder::new() 43 | /// .save_cookies() 44 | /// .default_content_type(&"application/json") 45 | /// .into_config(); 46 | /// 47 | /// // Build the Test Server 48 | /// let server = TestServer::new_with_config(my_app, config)?; 49 | /// # 50 | /// # Ok(()) 51 | /// # } 52 | /// ``` 53 | /// 54 | /// Configs made this way can be passed to [`crate::TestServer::new_with_config`]. 55 | /// 56 | #[derive(Debug, Clone)] 57 | pub struct TestServerBuilder { 58 | config: TestServerConfig, 59 | } 60 | 61 | impl TestServerBuilder { 62 | /// Creates a default `TestServerBuilder`.
63 | pub fn new() -> Self { 64 | Default::default() 65 | } 66 | /// Creates a `TestServerBuilder` starting from the given config. 67 | pub fn from_config(config: TestServerConfig) -> Self { 68 | Self { config } 69 | } 70 | /// Sets the server to run over real HTTP, on a random port. 71 | pub fn http_transport(self) -> Self { 72 | self.transport(Transport::HttpRandomPort) 73 | } 74 | /// Sets the server to run over real HTTP, using the IP and port given. 75 | pub fn http_transport_with_ip_port(self, ip: Option, port: Option) -> Self { 76 | self.transport(Transport::HttpIpPort { ip, port }) 77 | } 78 | /// Sets the server to use mocked HTTP (`Transport::MockHttp`). 79 | pub fn mock_transport(self) -> Self { 80 | self.transport(Transport::MockHttp) 81 | } 82 | /// Sets which transport the server uses to process requests. 83 | pub fn transport(mut self, transport: Transport) -> Self { 84 | self.config.transport = Some(transport); 85 | self 86 | } 87 | /// Saves cookies returned by responses, for use in future requests. 88 | pub fn save_cookies(mut self) -> Self { 89 | self.config.save_cookies = true; 90 | self 91 | } 92 | /// Turns off saving cookies returned by responses. 93 | pub fn do_not_save_cookies(mut self) -> Self { 94 | self.config.save_cookies = false; 95 | self 96 | } 97 | /// Sets the default content type for all requests made by the server. 98 | pub fn default_content_type(mut self, content_type: &str) -> Self { 99 | self.config.default_content_type = Some(content_type.to_string()); 100 | self 101 | } 102 | /// Sets the default scheme (such as `https`) for all requests made by the server. 103 | pub fn default_scheme(mut self, scheme: &str) -> Self { 104 | self.config.default_scheme = Some(scheme.to_string()); 105 | self 106 | } 107 | /// Expects every request to return a 2xx status code, unless overridden per request. 108 | pub fn expect_success_by_default(mut self) -> Self { 109 | self.config.expect_success_by_default = true; 110 | self 111 | } 112 | /// Stops requests to absolute `http://` urls from bypassing the test server's address. 113 | pub fn restrict_requests_with_http_schema(mut self) -> Self { 114 | self.config.restrict_requests_with_http_schema = true; 115 | self 116 | } 117 | 118 | /// For turning this into a [`crate::TestServerConfig`] object, 119 | /// which can be passed to [`crate::TestServer::new_with_config`].
120 | /// 121 | /// ```rust 122 | /// # async fn test() -> Result<(), Box> { 123 | /// # 124 | /// use axum::Router; 125 | /// use axum_test::TestServer; 126 | /// 127 | /// let my_app = Router::new(); 128 | /// let config = TestServer::builder() 129 | /// .save_cookies() 130 | /// .default_content_type(&"application/json") 131 | /// .into_config(); 132 | /// 133 | /// // Build the Test Server 134 | /// let server = TestServer::new_with_config(my_app, config)?; 135 | /// # 136 | /// # Ok(()) 137 | /// # } 138 | /// ``` 139 | pub fn into_config(self) -> TestServerConfig { 140 | self.config 141 | } 142 | 143 | /// Creates a new [`crate::TestServer`], running the application given, 144 | /// and with all settings from this `TestServerBuilder` applied. 145 | /// 146 | /// ```rust 147 | /// use axum::Router; 148 | /// use axum_test::TestServer; 149 | /// 150 | /// let app = Router::new(); 151 | /// let server = TestServer::builder() 152 | /// .save_cookies() 153 | /// .default_content_type(&"application/json") 154 | /// .build(app); 155 | /// ``` 156 | /// 157 | /// This is equivalent to building [`crate::TestServerConfig`] yourself, 158 | /// and calling [`crate::TestServer::new_with_config`].
159 | pub fn build(self, app: A) -> Result 160 | where 161 | A: IntoTransportLayer, 162 | { 163 | self.into_config().build(app) 164 | } 165 | } 166 | 167 | impl Default for TestServerBuilder { 168 | fn default() -> Self { 169 | Self { 170 | config: TestServerConfig::default(), 171 | } 172 | } 173 | } 174 | 175 | impl From for TestServerBuilder { 176 | fn from(config: TestServerConfig) -> Self { 177 | TestServerBuilder::from_config(config) 178 | } 179 | } 180 | 181 | #[cfg(test)] 182 | mod test_build { 183 | use super::*; 184 | use std::net::Ipv4Addr; 185 | 186 | #[test] 187 | fn it_should_build_default_config_by_default() { 188 | let config = TestServer::builder().into_config(); 189 | let expected = TestServerConfig::default(); 190 | 191 | assert_eq!(config, expected); 192 | } 193 | 194 | #[test] 195 | fn it_should_save_cookies_when_set() { 196 | let config = TestServer::builder().save_cookies().into_config(); 197 | 198 | assert!(config.save_cookies); 199 | } 200 | 201 | #[test] 202 | fn it_should_not_save_cookies_when_set() { 203 | let config = TestServer::builder().save_cookies().do_not_save_cookies().into_config(); 204 | // Starts from `save_cookies()`, so this proves `do_not_save_cookies()` turns it off. 205 | assert!(!config.save_cookies); 206 | } 207 | 208 | #[test] 209 | fn it_should_mock_transport_when_set() { 210 | let config = TestServer::builder().mock_transport().into_config(); 211 | 212 | assert_eq!(config.transport, Some(Transport::MockHttp)); 213 | } 214 | 215 | #[test] 216 | fn it_should_use_random_http_transport_when_set() { 217 | let config = TestServer::builder().http_transport().into_config(); 218 | 219 | assert_eq!(config.transport, Some(Transport::HttpRandomPort)); 220 | } 221 | 222 | #[test] 223 | fn it_should_use_http_transport_with_ip_port_when_set() { 224 | let config = TestServer::builder() 225 | .http_transport_with_ip_port(Some(IpAddr::V4(Ipv4Addr::new(123, 4, 5, 6))), Some(987)) 226 | .into_config(); 227 | 228 | assert_eq!( 229 | config.transport, 230 | Some(Transport::HttpIpPort { 231 | ip:
Some(IpAddr::V4(Ipv4Addr::new(123, 4, 5, 6))), 232 | port: Some(987), 233 | }) 234 | ); 235 | } 236 | 237 | #[test] 238 | fn it_should_set_default_content_type_when_set() { 239 | let config = TestServer::builder() 240 | .default_content_type("text/csv") 241 | .into_config(); 242 | 243 | assert_eq!(config.default_content_type, Some("text/csv".to_string())); 244 | } 245 | 246 | #[test] 247 | fn it_should_set_default_scheme_when_set() { 248 | let config = TestServer::builder().default_scheme("ftps").into_config(); 249 | 250 | assert_eq!(config.default_scheme, Some("ftps".to_string())); 251 | } 252 | 253 | #[test] 254 | fn it_should_set_expect_success_by_default_when_set() { 255 | let config = TestServer::builder() 256 | .expect_success_by_default() 257 | .into_config(); 258 | 259 | assert!(config.expect_success_by_default); 260 | } 261 | 262 | #[test] 263 | fn it_should_set_restrict_requests_with_http_schema_when_set() { 264 | let config = TestServer::builder() 265 | .restrict_requests_with_http_schema() 266 | .into_config(); 267 | 268 | assert!(config.restrict_requests_with_http_schema); 269 | } 270 | } 271 | -------------------------------------------------------------------------------- /src/test_server_config.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | 3 | use crate::transport_layer::IntoTransportLayer; 4 | use crate::TestServer; 5 | use crate::TestServerBuilder; 6 | use crate::Transport; 7 | 8 | /// This is for customising the [`TestServer`](crate::TestServer) on construction. 9 | /// It implements [`Default`] to ease building.
10 | /// 11 | /// ```rust 12 | /// use axum_test::TestServerConfig; 13 | /// 14 | /// let config = TestServerConfig { 15 | /// save_cookies: true, 16 | /// ..TestServerConfig::default() 17 | /// }; 18 | /// ``` 19 | /// 20 | /// These can be passed to `TestServer::new_with_config`: 21 | /// 22 | /// ```rust 23 | /// # async fn test() -> Result<(), Box> { 24 | /// # 25 | /// use axum::Router; 26 | /// use axum_test::TestServer; 27 | /// use axum_test::TestServerConfig; 28 | /// 29 | /// let my_app = Router::new(); 30 | /// 31 | /// let config = TestServerConfig { 32 | /// save_cookies: true, 33 | /// ..TestServerConfig::default() 34 | /// }; 35 | /// 36 | /// // Build the Test Server 37 | /// let server = TestServer::new_with_config(my_app, config)?; 38 | /// # 39 | /// # Ok(()) 40 | /// # } 41 | /// ``` 42 | /// 43 | #[derive(Debug, Clone, Eq, PartialEq)] 44 | pub struct TestServerConfig { 45 | /// Which transport mode to use to process requests. 46 | /// For setting if the server should use mocked http (which uses [`tower::util::Oneshot`](tower::util::Oneshot)), 47 | /// or if it should run on a named or random IP address. 48 | /// 49 | /// The default is to use mocking, apart from services built using [`axum::extract::connect_info::IntoMakeServiceWithConnectInfo`](axum::extract::connect_info::IntoMakeServiceWithConnectInfo) 50 | /// (this is because it needs a real TCP stream). 51 | pub transport: Option, 52 | 53 | /// Set for the server to save cookies that are returned, 54 | /// for use in future requests. 55 | /// 56 | /// This is useful for automatically saving session cookies (and similar) 57 | /// like a browser would do. 58 | /// 59 | /// **Defaults** to false (being turned off). 60 | pub save_cookies: bool, 61 | 62 | /// Asserts that requests made to the test server 63 | /// will, by default, 64 | /// return a status code in the 2xx range.
65 | /// 66 | /// This can be overridden on a per-request basis using 67 | /// [`TestRequest::expect_failure()`](crate::TestRequest::expect_failure()). 68 | /// 69 | /// This is useful when making multiple requests at the start of a test 70 | /// which you presume should always work. 71 | /// 72 | /// **Defaults** to false (being turned off). 73 | pub expect_success_by_default: bool, 74 | 75 | /// If you make a request with an 'http://' schema, 76 | /// then it will ignore the Test Server's address. 77 | /// 78 | /// For example, if the test server is running at `http://localhost:1234`, 79 | /// and you make a request to `http://google.com`, 80 | /// then the request will go to `http://google.com`, 81 | /// ignoring the `localhost:1234` part. 82 | /// 83 | /// Turning this setting on will change this behaviour. 84 | /// 85 | /// After turning this on, the same request will go to 86 | /// `http://localhost:1234/http://google.com`. 87 | /// 88 | /// **Defaults** to false (being turned off). 89 | pub restrict_requests_with_http_schema: bool, 90 | 91 | /// Set the default content type for all requests created by the `TestServer`. 92 | /// 93 | /// This overrides the default 'best efforts' approach of requests. 94 | pub default_content_type: Option, 95 | 96 | /// Set the default scheme to use for all requests created by the `TestServer`. 97 | /// 98 | /// This overrides the default 'http'. 99 | pub default_scheme: Option, 100 | } 101 | 102 | impl TestServerConfig { 103 | /// Creates a default `TestServerConfig`. 104 | pub fn new() -> Self { 105 | Default::default() 106 | } 107 | 108 | /// This is shorthand for calling [`crate::TestServer::new_with_config`], 109 | /// and passing this config.
110 | /// 111 | /// ```rust 112 | /// # async fn test() -> Result<(), Box> { 113 | /// # 114 | /// use axum::Router; 115 | /// use axum_test::TestServer; 116 | /// use axum_test::TestServerConfig; 117 | /// 118 | /// let app = Router::new(); 119 | /// let config = TestServerConfig { 120 | /// save_cookies: true, 121 | /// default_content_type: Some("application/json".to_string()), 122 | /// ..Default::default() 123 | /// }; 124 | /// let server: TestServer = config.build(app)?; 125 | /// # 126 | /// # Ok(()) 127 | /// # } 128 | /// ``` 129 | pub fn build(self, app: A) -> Result 130 | where 131 | A: IntoTransportLayer, 132 | { 133 | TestServer::new_with_config(app, self) 134 | } 135 | } 136 | 137 | impl Default for TestServerConfig { 138 | fn default() -> Self { 139 | Self { 140 | transport: None, 141 | save_cookies: false, 142 | expect_success_by_default: false, 143 | restrict_requests_with_http_schema: false, 144 | default_content_type: None, 145 | default_scheme: None, 146 | } 147 | } 148 | } 149 | 150 | impl From for TestServerConfig { 151 | fn from(builder: TestServerBuilder) -> Self { 152 | builder.into_config() 153 | } 154 | } 155 | 156 | #[cfg(test)] 157 | mod test_scheme { 158 | use axum::extract::Request; 159 | use axum::routing::get; 160 | use axum::Router; 161 | 162 | use crate::TestServer; 163 | use crate::TestServerConfig; 164 | 165 | async fn route_get_scheme(request: Request) -> String { 166 | request.uri().scheme_str().unwrap().to_string() 167 | } 168 | 169 | #[tokio::test] 170 | async fn it_should_set_scheme_when_present_in_config() { 171 | let router = Router::new().route("/scheme", get(route_get_scheme)); 172 | 173 | let config = TestServerConfig { 174 | default_scheme: Some("https".to_string()), 175 | ..Default::default() 176 | }; 177 | let server = TestServer::new_with_config(router, config).unwrap(); 178 | 179 | server.get("/scheme").await.assert_text("https"); 180 | } 181 | } 182 |
-------------------------------------------------------------------------------- /src/test_web_socket.rs: -------------------------------------------------------------------------------- 1 | use crate::WsMessage; 2 | use anyhow::anyhow; 3 | use anyhow::Context; 4 | use anyhow::Result; 5 | use bytes::Bytes; 6 | use futures_util::sink::SinkExt; 7 | use futures_util::stream::StreamExt; 8 | use hyper::upgrade::Upgraded; 9 | use hyper_util::rt::TokioIo; 10 | use serde::de::DeserializeOwned; 11 | use serde::Serialize; 12 | use std::fmt::Debug; 13 | use std::fmt::Display; 14 | use tokio_tungstenite::tungstenite::protocol::Role; 15 | use tokio_tungstenite::WebSocketStream; 16 | 17 | #[cfg(feature = "pretty-assertions")] 18 | use pretty_assertions::assert_eq; 19 | 20 | #[derive(Debug)] 21 | pub struct TestWebSocket { 22 | stream: WebSocketStream>, 23 | } 24 | 25 | impl TestWebSocket { 26 | pub(crate) async fn new(upgraded: Upgraded) -> Self { 27 | let upgraded_io = TokioIo::new(upgraded); 28 | let stream = WebSocketStream::from_raw_socket(upgraded_io, Role::Client, None).await; 29 | 30 | Self { stream } 31 | } 32 | 33 | pub async fn close(mut self) { 34 | self.stream 35 | .close(None) 36 | .await 37 | .expect("Failed to close WebSocket stream"); 38 | } 39 | 40 | pub async fn send_text(&mut self, raw_text: T) 41 | where 42 | T: Display, 43 | { 44 | let text = format!("{}", raw_text); 45 | self.send_message(WsMessage::Text(text.into())).await; 46 | } 47 | 48 | pub async fn send_json(&mut self, body: &J) 49 | where 50 | J: ?Sized + Serialize, 51 | { 52 | let raw_json = 53 | ::serde_json::to_string(body).expect("It should serialize the content into Json"); 54 | 55 | self.send_message(WsMessage::Text(raw_json.into())).await; 56 | } 57 | 58 | #[cfg(feature = "yaml")] 59 | pub async fn send_yaml(&mut self, body: &Y) 60 | where 61 | Y: ?Sized + Serialize, 62 | { 63 | let raw_yaml = 64 | ::serde_yaml::to_string(body).expect("It should serialize the content into Yaml"); 65 | 66 | 
self.send_message(WsMessage::Text(raw_yaml.into())).await; 67 | } 68 | 69 | #[cfg(feature = "msgpack")] 70 | pub async fn send_msgpack(&mut self, body: &M) 71 | where 72 | M: ?Sized + Serialize, 73 | { 74 | let body_bytes = 75 | ::rmp_serde::to_vec(body).expect("It should serialize the content into MsgPack"); 76 | 77 | self.send_message(WsMessage::Binary(body_bytes.into())) 78 | .await; 79 | } 80 | 81 | pub async fn send_message(&mut self, message: WsMessage) { 82 | self.stream.send(message).await.unwrap(); 83 | } 84 | 85 | #[must_use] 86 | pub async fn receive_text(&mut self) -> String { 87 | let message = self.receive_message().await; 88 | 89 | message_to_text(message) 90 | .context("Failed to read message as a String") 91 | .unwrap() 92 | } 93 | 94 | #[must_use] 95 | pub async fn receive_json(&mut self) -> T 96 | where 97 | T: DeserializeOwned, 98 | { 99 | let bytes = self.receive_bytes().await; 100 | serde_json::from_slice::(&bytes) 101 | .context("Failed to deserialize message as Json") 102 | .unwrap() 103 | } 104 | 105 | #[cfg(feature = "yaml")] 106 | #[must_use] 107 | pub async fn receive_yaml(&mut self) -> T 108 | where 109 | T: DeserializeOwned, 110 | { 111 | let bytes = self.receive_bytes().await; 112 | serde_yaml::from_slice::(&bytes) 113 | .context("Failed to deserialize message as Yaml") 114 | .unwrap() 115 | } 116 | 117 | #[cfg(feature = "msgpack")] 118 | #[must_use] 119 | pub async fn receive_msgpack(&mut self) -> T 120 | where 121 | T: DeserializeOwned, 122 | { 123 | let received_bytes = self.receive_bytes().await; 124 | rmp_serde::from_slice::(&received_bytes) 125 | .context("Failed to deserializing message as MsgPack") 126 | .unwrap() 127 | } 128 | 129 | #[must_use] 130 | pub async fn receive_bytes(&mut self) -> Bytes { 131 | let message = self.receive_message().await; 132 | 133 | message_to_bytes(message) 134 | .context("Failed to read message as a Bytes") 135 | .unwrap() 136 | } 137 | 138 | #[must_use] 139 | pub async fn receive_message(&mut 
self) -> WsMessage { 140 | self.maybe_receive_message() 141 | .await 142 | .expect("No message found on WebSocket stream") 143 | } 144 | 145 | pub async fn assert_receive_json(&mut self, expected: &T) 146 | where 147 | T: DeserializeOwned + PartialEq + Debug, 148 | { 149 | assert_eq!(*expected, self.receive_json::().await); 150 | } 151 | 152 | pub async fn assert_receive_text(&mut self, expected: C) 153 | where 154 | C: AsRef, 155 | { 156 | let expected_contents = expected.as_ref(); 157 | assert_eq!(expected_contents, &self.receive_text().await); 158 | } 159 | 160 | pub async fn assert_receive_text_contains(&mut self, expected: C) 161 | where 162 | C: AsRef, 163 | { 164 | let expected_contents = expected.as_ref(); 165 | let received = self.receive_text().await; 166 | let is_contained = received.contains(expected_contents); 167 | 168 | assert!( 169 | is_contained, 170 | "Failed to find '{expected_contents}', received '{received}'" 171 | ); 172 | } 173 | 174 | #[cfg(feature = "yaml")] 175 | pub async fn assert_receive_yaml(&mut self, expected: &T) 176 | where 177 | T: DeserializeOwned + PartialEq + Debug, 178 | { 179 | assert_eq!(*expected, self.receive_yaml::().await); 180 | } 181 | 182 | #[cfg(feature = "msgpack")] 183 | pub async fn assert_receive_msgpack(&mut self, expected: &T) 184 | where 185 | T: DeserializeOwned + PartialEq + Debug, 186 | { 187 | assert_eq!(*expected, self.receive_msgpack::().await); 188 | } 189 | 190 | #[must_use] 191 | async fn maybe_receive_message(&mut self) -> Option { 192 | let maybe_message = self.stream.next().await; 193 | 194 | match maybe_message { 195 | None => None, 196 | Some(message_result) => { 197 | let message = 198 | message_result.expect("Failed to receive message from WebSocket stream"); 199 | Some(message) 200 | } 201 | } 202 | } 203 | } 204 | 205 | fn message_to_text(message: WsMessage) -> Result { 206 | let text = match message { 207 | WsMessage::Text(text) => text.to_string(), 208 | WsMessage::Binary(data) => { 209 | 
String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())? 210 | } 211 | WsMessage::Ping(data) => { 212 | String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())? 213 | } 214 | WsMessage::Pong(data) => { 215 | String::from_utf8(data.to_vec()).map_err(|err| err.utf8_error())? 216 | } 217 | WsMessage::Close(None) => String::new(), 218 | WsMessage::Close(Some(frame)) => frame.reason.to_string(), 219 | WsMessage::Frame(_) => { 220 | return Err(anyhow!( 221 | "Unexpected Frame, did not expect Frame message whilst reading" 222 | )) 223 | } 224 | }; 225 | 226 | Ok(text) 227 | } 228 | 229 | fn message_to_bytes(message: WsMessage) -> Result { 230 | let bytes = match message { 231 | WsMessage::Text(string) => string.into(), 232 | WsMessage::Binary(data) => data, 233 | WsMessage::Ping(data) => data, 234 | WsMessage::Pong(data) => data, 235 | WsMessage::Close(None) => Bytes::new(), 236 | WsMessage::Close(Some(frame)) => frame.reason.into(), 237 | WsMessage::Frame(_) => { 238 | return Err(anyhow!( 239 | "Unexpected Frame, did not expect Frame message whilst reading" 240 | )) 241 | } 242 | }; 243 | 244 | Ok(bytes) 245 | } 246 | 247 | #[cfg(test)] 248 | mod test_assert_receive_text { 249 | use crate::TestServer; 250 | 251 | use axum::extract::ws::Message; 252 | use axum::extract::ws::WebSocket; 253 | use axum::extract::WebSocketUpgrade; 254 | use axum::response::Response; 255 | use axum::routing::get; 256 | use axum::Router; 257 | 258 | fn new_test_app() -> TestServer { 259 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 260 | async fn handle_ping_pong(mut socket: WebSocket) { 261 | while let Some(maybe_message) = socket.recv().await { 262 | let message_text = maybe_message.unwrap().into_text().unwrap(); 263 | 264 | let encoded_text = format!("Text: {message_text}").try_into().unwrap(); 265 | let encoded_data = format!("Binary: {message_text}").into_bytes().into(); 266 | 267 | 
socket.send(Message::Text(encoded_text)).await.unwrap(); 268 | socket.send(Message::Binary(encoded_data)).await.unwrap(); 269 | } 270 | } 271 | 272 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 273 | } 274 | 275 | let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)); 276 | TestServer::builder().http_transport().build(app).unwrap() 277 | } 278 | 279 | #[tokio::test] 280 | async fn it_should_ping_pong_text_in_text_and_binary() { 281 | let server = new_test_app(); 282 | 283 | let mut websocket = server 284 | .get_websocket(&"/ws-ping-pong") 285 | .await 286 | .into_websocket() 287 | .await; 288 | 289 | websocket.send_text("Hello World!").await; 290 | 291 | websocket.assert_receive_text("Text: Hello World!").await; 292 | websocket.assert_receive_text("Binary: Hello World!").await; 293 | } 294 | 295 | #[tokio::test] 296 | async fn it_should_ping_pong_large_text_blobs() { 297 | const LARGE_BLOB_SIZE: usize = 16777200; // Max websocket size (16mb) - 16 bytes for the 'Text: ' in the reply. 
298 | let large_blob = (0..LARGE_BLOB_SIZE).map(|_| "X").collect::(); 299 | 300 | let server = new_test_app(); 301 | let mut websocket = server 302 | .get_websocket(&"/ws-ping-pong") 303 | .await 304 | .into_websocket() 305 | .await; 306 | 307 | websocket.send_text(&large_blob).await; 308 | 309 | websocket 310 | .assert_receive_text(format!("Text: {large_blob}")) 311 | .await; 312 | websocket 313 | .assert_receive_text(format!("Binary: {large_blob}")) 314 | .await; 315 | } 316 | 317 | #[tokio::test] 318 | #[should_panic] 319 | async fn it_should_not_match_partial_text_match() { 320 | let server = new_test_app(); 321 | 322 | let mut websocket = server 323 | .get_websocket(&"/ws-ping-pong") 324 | .await 325 | .into_websocket() 326 | .await; 327 | 328 | websocket.send_text("Hello World!").await; 329 | websocket.assert_receive_text("Hello World!").await; 330 | } 331 | 332 | #[tokio::test] 333 | #[should_panic] 334 | async fn it_should_not_match_different_text() { 335 | let server = new_test_app(); 336 | 337 | let mut websocket = server 338 | .get_websocket(&"/ws-ping-pong") 339 | .await 340 | .into_websocket() 341 | .await; 342 | 343 | websocket.send_text("Hello World!").await; 344 | websocket.assert_receive_text("🦊").await; 345 | } 346 | } 347 | 348 | #[cfg(test)] 349 | mod test_assert_receive_text_contains { 350 | use crate::TestServer; 351 | 352 | use axum::extract::ws::Message; 353 | use axum::extract::ws::WebSocket; 354 | use axum::extract::WebSocketUpgrade; 355 | use axum::response::Response; 356 | use axum::routing::get; 357 | use axum::Router; 358 | 359 | fn new_test_app() -> TestServer { 360 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 361 | async fn handle_ping_pong(mut socket: WebSocket) { 362 | while let Some(maybe_message) = socket.recv().await { 363 | let message_text = maybe_message.unwrap().into_text().unwrap(); 364 | let encoded_text = format!("Text: {message_text}").try_into().unwrap(); 365 | 366 | 
socket.send(Message::Text(encoded_text)).await.unwrap(); 367 | } 368 | } 369 | 370 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 371 | } 372 | 373 | let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)); 374 | TestServer::builder().http_transport().build(app).unwrap() 375 | } 376 | 377 | #[tokio::test] 378 | async fn it_should_assert_whole_text_match() { 379 | let server = new_test_app(); 380 | 381 | let mut websocket = server 382 | .get_websocket(&"/ws-ping-pong") 383 | .await 384 | .into_websocket() 385 | .await; 386 | 387 | websocket.send_text("Hello World!").await; 388 | websocket 389 | .assert_receive_text_contains("Text: Hello World!") 390 | .await; 391 | } 392 | 393 | #[tokio::test] 394 | async fn it_should_assert_partial_text_match() { 395 | let server = new_test_app(); 396 | 397 | let mut websocket = server 398 | .get_websocket(&"/ws-ping-pong") 399 | .await 400 | .into_websocket() 401 | .await; 402 | 403 | websocket.send_text("Hello World!").await; 404 | websocket.assert_receive_text_contains("Hello World!").await; 405 | } 406 | 407 | #[tokio::test] 408 | #[should_panic] 409 | async fn it_should_not_match_different_text() { 410 | let server = new_test_app(); 411 | 412 | let mut websocket = server 413 | .get_websocket(&"/ws-ping-pong") 414 | .await 415 | .into_websocket() 416 | .await; 417 | 418 | websocket.send_text("Hello World!").await; 419 | websocket.assert_receive_text_contains("🦊").await; 420 | } 421 | } 422 | 423 | #[cfg(test)] 424 | mod test_assert_receive_json { 425 | use crate::TestServer; 426 | 427 | use axum::extract::ws::Message; 428 | use axum::extract::ws::WebSocket; 429 | use axum::extract::WebSocketUpgrade; 430 | use axum::response::Response; 431 | use axum::routing::get; 432 | use axum::Router; 433 | use serde_json::json; 434 | use serde_json::Value; 435 | 436 | fn new_test_app() -> TestServer { 437 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 438 | async fn 
handle_ping_pong(mut socket: WebSocket) { 439 | while let Some(maybe_message) = socket.recv().await { 440 | let message_text = maybe_message.unwrap().into_text().unwrap(); 441 | let decoded = serde_json::from_str::(&message_text).unwrap(); 442 | 443 | let encoded_text = serde_json::to_string(&json!({ 444 | "format": "text", 445 | "message": decoded 446 | })) 447 | .unwrap() 448 | .try_into() 449 | .unwrap(); 450 | let encoded_data = serde_json::to_vec(&json!({ 451 | "format": "binary", 452 | "message": decoded 453 | })) 454 | .unwrap() 455 | .into(); 456 | 457 | socket.send(Message::Text(encoded_text)).await.unwrap(); 458 | socket.send(Message::Binary(encoded_data)).await.unwrap(); 459 | } 460 | } 461 | 462 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 463 | } 464 | 465 | let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)); 466 | TestServer::builder().http_transport().build(app).unwrap() 467 | } 468 | 469 | #[tokio::test] 470 | async fn it_should_ping_pong_json_in_text_and_binary() { 471 | let server = new_test_app(); 472 | 473 | let mut websocket = server 474 | .get_websocket(&"/ws-ping-pong") 475 | .await 476 | .into_websocket() 477 | .await; 478 | 479 | websocket 480 | .send_json(&json!({ 481 | "hello": "world", 482 | "numbers": [1, 2, 3], 483 | })) 484 | .await; 485 | 486 | // Once for text 487 | websocket 488 | .assert_receive_json(&json!({ 489 | "format": "text", 490 | "message": { 491 | "hello": "world", 492 | "numbers": [1, 2, 3], 493 | }, 494 | })) 495 | .await; 496 | 497 | // Again for binary 498 | websocket 499 | .assert_receive_json(&json!({ 500 | "format": "binary", 501 | "message": { 502 | "hello": "world", 503 | "numbers": [1, 2, 3], 504 | }, 505 | })) 506 | .await; 507 | } 508 | } 509 | 510 | #[cfg(feature = "yaml")] 511 | #[cfg(test)] 512 | mod test_assert_receive_yaml { 513 | use crate::TestServer; 514 | 515 | use axum::extract::ws::Message; 516 | use axum::extract::ws::WebSocket; 517 | use 
axum::extract::WebSocketUpgrade; 518 | use axum::response::Response; 519 | use axum::routing::get; 520 | use axum::Router; 521 | use serde_json::json; 522 | use serde_json::Value; 523 | 524 | fn new_test_app() -> TestServer { 525 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 526 | async fn handle_ping_pong(mut socket: WebSocket) { 527 | while let Some(maybe_message) = socket.recv().await { 528 | let message_text = maybe_message.unwrap().into_text().unwrap(); 529 | let decoded = serde_yaml::from_str::(&message_text).unwrap(); 530 | 531 | let encoded_text = serde_yaml::to_string(&json!({ 532 | "format": "text", 533 | "message": decoded 534 | })) 535 | .unwrap() 536 | .try_into() 537 | .unwrap(); 538 | let encoded_data = serde_yaml::to_string(&json!({ 539 | "format": "binary", 540 | "message": decoded 541 | })) 542 | .unwrap() 543 | .into(); 544 | 545 | socket.send(Message::Text(encoded_text)).await.unwrap(); 546 | socket.send(Message::Binary(encoded_data)).await.unwrap(); 547 | } 548 | } 549 | 550 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 551 | } 552 | 553 | let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)); 554 | TestServer::builder().http_transport().build(app).unwrap() 555 | } 556 | 557 | #[tokio::test] 558 | async fn it_should_ping_pong_yaml_in_text_and_binary() { 559 | let server = new_test_app(); 560 | 561 | let mut websocket = server 562 | .get_websocket(&"/ws-ping-pong") 563 | .await 564 | .into_websocket() 565 | .await; 566 | 567 | websocket 568 | .send_json(&json!({ 569 | "hello": "world", 570 | "numbers": [1, 2, 3], 571 | })) 572 | .await; 573 | 574 | // Once for text 575 | websocket 576 | .assert_receive_yaml(&json!({ 577 | "format": "text", 578 | "message": { 579 | "hello": "world", 580 | "numbers": [1, 2, 3], 581 | }, 582 | })) 583 | .await; 584 | 585 | // Again for binary 586 | websocket 587 | .assert_receive_yaml(&json!({ 588 | "format": "binary", 589 | "message": { 590 
| "hello": "world", 591 | "numbers": [1, 2, 3], 592 | }, 593 | })) 594 | .await; 595 | } 596 | } 597 | 598 | #[cfg(feature = "msgpack")] 599 | #[cfg(test)] 600 | mod test_assert_receive_msgpack { 601 | use crate::TestServer; 602 | 603 | use axum::extract::ws::Message; 604 | use axum::extract::ws::WebSocket; 605 | use axum::extract::WebSocketUpgrade; 606 | use axum::response::Response; 607 | use axum::routing::get; 608 | use axum::Router; 609 | use serde_json::json; 610 | use serde_json::Value; 611 | 612 | fn new_test_app() -> TestServer { 613 | pub async fn route_get_websocket_ping_pong(ws: WebSocketUpgrade) -> Response { 614 | async fn handle_ping_pong(mut socket: WebSocket) { 615 | while let Some(maybe_message) = socket.recv().await { 616 | let message_data = maybe_message.unwrap().into_data(); 617 | let decoded = rmp_serde::from_slice::(&message_data).unwrap(); 618 | 619 | let encoded_data = ::rmp_serde::to_vec(&json!({ 620 | "format": "binary", 621 | "message": decoded 622 | })) 623 | .unwrap() 624 | .into(); 625 | 626 | socket.send(Message::Binary(encoded_data)).await.unwrap(); 627 | } 628 | } 629 | 630 | ws.on_upgrade(move |socket| handle_ping_pong(socket)) 631 | } 632 | 633 | let app = Router::new().route(&"/ws-ping-pong", get(route_get_websocket_ping_pong)); 634 | TestServer::builder().http_transport().build(app).unwrap() 635 | } 636 | 637 | #[tokio::test] 638 | async fn it_should_ping_pong_msgpack_in_binary() { 639 | let server = new_test_app(); 640 | 641 | let mut websocket = server 642 | .get_websocket(&"/ws-ping-pong") 643 | .await 644 | .into_websocket() 645 | .await; 646 | 647 | websocket 648 | .send_msgpack(&json!({ 649 | "hello": "world", 650 | "numbers": [1, 2, 3], 651 | })) 652 | .await; 653 | 654 | websocket 655 | .assert_receive_msgpack(&json!({ 656 | "format": "binary", 657 | "message": { 658 | "hello": "world", 659 | "numbers": [1, 2, 3], 660 | }, 661 | })) 662 | .await; 663 | } 664 | } 665 | 
-------------------------------------------------------------------------------- /src/transport.rs: -------------------------------------------------------------------------------- 1 | use std::net::IpAddr; 2 | 3 | /// Transport is for setting which transport mode for the `TestServer` 4 | /// to use when making requests. 5 | #[derive(Debug, Copy, Clone, Eq, PartialEq)] 6 | pub enum Transport { 7 | /// With this transport mode, `TestRequest` will use a mock HTTP 8 | /// transport. 9 | /// 10 | /// This is the Default Transport type. 11 | MockHttp, 12 | 13 | /// With this transport mode, a real web server will be spun up 14 | /// running on a random port. Requests made using the `TestRequest` 15 | /// will be made over the network stack. 16 | HttpRandomPort, 17 | 18 | /// With this transport mode, a real web server will be spun up. 19 | /// Where you can pick which IP and Port to use for this to bind to. 20 | /// 21 | /// Setting both `ip` and `port` to `None`, is the equivalent of 22 | /// using `Transport::HttpRandomPort`. 23 | HttpIpPort { 24 | /// Set the IP to use for the server. 25 | /// 26 | /// **Defaults** to `127.0.0.1`. 27 | ip: Option, 28 | 29 | /// Set the port number to use for the server. 30 | /// 31 | /// **Defaults** to a _random_ port. 
32 | port: Option, 33 | }, 34 | } 35 | 36 | impl Default for Transport { 37 | fn default() -> Self { 38 | Self::MockHttp 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | 3 | use crate::transport_layer::TransportLayer; 4 | use crate::transport_layer::TransportLayerBuilder; 5 | 6 | // mod into_make_service_tower; 7 | 8 | mod into_make_service; 9 | mod into_make_service_with_connect_info; 10 | mod router; 11 | mod serve; 12 | mod with_graceful_shutdown; 13 | 14 | #[cfg(feature = "shuttle")] 15 | mod axum_service; 16 | #[cfg(feature = "shuttle")] 17 | mod shuttle_axum; 18 | 19 | /// 20 | /// This exists to unify how to send mock or real messages to different services. 21 | /// This includes differences between [`Router`](::axum::Router), 22 | /// [`IntoMakeService`](::axum::routing::IntoMakeService), 23 | /// and [`IntoMakeServiceWithConnectInfo`](::axum::extract::connect_info::IntoMakeServiceWithConnectInfo). 24 | /// 25 | /// Implementing this will allow you to use the `TestServer` against other types. 26 | /// 27 | /// **Warning**, this trait may change in a future release. 
28 | /// 29 | pub trait IntoTransportLayer: Sized { 30 | fn into_http_transport_layer( 31 | self, 32 | builder: TransportLayerBuilder, 33 | ) -> Result>; 34 | 35 | fn into_mock_transport_layer(self) -> Result>; 36 | 37 | fn into_default_transport( 38 | self, 39 | _builder: TransportLayerBuilder, 40 | ) -> Result> { 41 | self.into_mock_transport_layer() 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/axum_service.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use axum::Router; 3 | use shuttle_axum::AxumService; 4 | 5 | use crate::transport_layer::IntoTransportLayer; 6 | use crate::transport_layer::TransportLayer; 7 | use crate::transport_layer::TransportLayerBuilder; 8 | 9 | impl IntoTransportLayer for AxumService { 10 | fn into_http_transport_layer( 11 | self, 12 | builder: TransportLayerBuilder, 13 | ) -> Result> { 14 | Router::into_http_transport_layer(self.0, builder) 15 | } 16 | 17 | fn into_mock_transport_layer(self) -> Result> { 18 | Router::into_mock_transport_layer(self.0) 19 | } 20 | } 21 | 22 | #[cfg(test)] 23 | mod test_into_http_transport_layer_for_axum_service { 24 | use super::*; 25 | 26 | use axum::extract::State; 27 | use axum::routing::get; 28 | use axum::Router; 29 | 30 | use crate::TestServer; 31 | 32 | async fn get_state(State(count): State) -> String { 33 | format!("count is {}", count) 34 | } 35 | 36 | #[tokio::test] 37 | async fn it_should_run() { 38 | // Build an application with a route. 39 | let app: AxumService = Router::new() 40 | .route("/count", get(get_state)) 41 | .with_state(123) 42 | .into(); 43 | 44 | // Run the server. 45 | let server = TestServer::builder() 46 | .http_transport() 47 | .build(app) 48 | .expect("Should create test server"); 49 | 50 | // Get the request. 
51 | server.get(&"/count").await.assert_text(&"count is 123"); 52 | } 53 | } 54 | 55 | #[cfg(test)] 56 | mod test_into_mock_transport_layer_for_axum_service { 57 | use super::*; 58 | 59 | use axum::extract::State; 60 | use axum::routing::get; 61 | use axum::Router; 62 | 63 | use crate::TestServer; 64 | 65 | async fn get_state(State(count): State) -> String { 66 | format!("count is {}", count) 67 | } 68 | 69 | #[tokio::test] 70 | async fn it_should_run() { 71 | // Build an application with a route. 72 | let app: AxumService = Router::new() 73 | .route("/count", get(get_state)) 74 | .with_state(123) 75 | .into(); 76 | 77 | // Run the server. 78 | let server = TestServer::builder() 79 | .mock_transport() 80 | .build(app) 81 | .expect("Should create test server"); 82 | 83 | // Get the request. 84 | server.get(&"/count").await.assert_text(&"count is 123"); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/into_make_service.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use axum::extract::Request as AxumRequest; 3 | use axum::response::Response as AxumResponse; 4 | use axum::routing::IntoMakeService; 5 | use std::convert::Infallible; 6 | use tower::Service; 7 | use url::Url; 8 | 9 | use crate::internals::HttpTransportLayer; 10 | use crate::internals::MockTransportLayer; 11 | use crate::transport_layer::IntoTransportLayer; 12 | use crate::transport_layer::TransportLayer; 13 | use crate::transport_layer::TransportLayerBuilder; 14 | use crate::util::spawn_serve; 15 | 16 | impl IntoTransportLayer for IntoMakeService 17 | where 18 | S: Service 19 | + Clone 20 | + Send 21 | + Sync 22 | + 'static, 23 | S::Future: Send, 24 | { 25 | fn into_http_transport_layer( 26 | self, 27 | builder: TransportLayerBuilder, 28 | ) -> Result> { 29 | let (socket_addr, tcp_listener, maybe_reserved_port) = 30 | 
builder.tcp_listener_with_reserved_port()?; 31 | 32 | let serve_handle = spawn_serve(tcp_listener, self); 33 | let server_address = format!("http://{socket_addr}"); 34 | let server_url: Url = server_address.parse()?; 35 | 36 | Ok(Box::new(HttpTransportLayer::new( 37 | serve_handle, 38 | maybe_reserved_port, 39 | server_url, 40 | ))) 41 | } 42 | 43 | fn into_mock_transport_layer(self) -> Result> { 44 | let transport_layer = MockTransportLayer::new(self); 45 | Ok(Box::new(transport_layer)) 46 | } 47 | } 48 | 49 | #[cfg(test)] 50 | mod test_into_http_transport_layer_for_into_make_service { 51 | use crate::TestServer; 52 | use axum::extract::Request; 53 | use axum::extract::State; 54 | use axum::routing::get; 55 | use axum::Router; 56 | use axum::ServiceExt; 57 | use tower::Layer; 58 | use tower_http::normalize_path::NormalizePathLayer; 59 | 60 | async fn get_ping() -> &'static str { 61 | "pong!" 62 | } 63 | 64 | async fn get_state(State(count): State) -> String { 65 | format!("count is {}", count) 66 | } 67 | 68 | #[tokio::test] 69 | async fn it_should_create_and_test_with_make_into_service() { 70 | // Build an application with a route. 71 | let app = Router::new() 72 | .route("/ping", get(get_ping)) 73 | .into_make_service(); 74 | 75 | // Run the server. 76 | let server = TestServer::builder() 77 | .http_transport() 78 | .build(app) 79 | .expect("Should create test server"); 80 | 81 | // Get the request. 82 | server.get(&"/ping").await.assert_text(&"pong!"); 83 | } 84 | 85 | #[tokio::test] 86 | async fn it_should_create_and_test_with_make_into_service_with_state() { 87 | // Build an application with a route. 88 | let app = Router::new() 89 | .route("/count", get(get_state)) 90 | .with_state(123) 91 | .into_make_service(); 92 | 93 | // Run the server. 94 | let server = TestServer::builder() 95 | .http_transport() 96 | .build(app) 97 | .expect("Should create test server"); 98 | 99 | // Get the request. 
100 | server.get(&"/count").await.assert_text(&"count is 123"); 101 | } 102 | 103 | #[tokio::test] 104 | async fn it_should_create_and_run_with_router_wrapped_service() { 105 | // Build an application with a route. 106 | let router = Router::new() 107 | .route("/count", get(get_state)) 108 | .with_state(123); 109 | let normalized_router = NormalizePathLayer::trim_trailing_slash().layer(router); 110 | let app = ServiceExt::::into_make_service(normalized_router); 111 | 112 | // Run the server. 113 | let server = TestServer::builder() 114 | .http_transport() 115 | .build(app) 116 | .expect("Should create test server"); 117 | 118 | // Get the request. 119 | server.get(&"/count").await.assert_text(&"count is 123"); 120 | } 121 | } 122 | 123 | #[cfg(test)] 124 | mod test_into_mock_transport_layer_for_into_make_service { 125 | use crate::TestServer; 126 | use axum::extract::Request; 127 | use axum::extract::State; 128 | use axum::routing::get; 129 | use axum::Router; 130 | use axum::ServiceExt; 131 | use tower::Layer; 132 | use tower_http::normalize_path::NormalizePathLayer; 133 | 134 | async fn get_ping() -> &'static str { 135 | "pong!" 136 | } 137 | 138 | async fn get_state(State(count): State) -> String { 139 | format!("count is {}", count) 140 | } 141 | 142 | #[tokio::test] 143 | async fn it_should_create_and_test_with_make_into_service() { 144 | // Build an application with a route. 145 | let app = Router::new() 146 | .route("/ping", get(get_ping)) 147 | .into_make_service(); 148 | 149 | // Run the server. 150 | let server = TestServer::builder() 151 | .mock_transport() 152 | .build(app) 153 | .expect("Should create test server"); 154 | 155 | // Get the request. 156 | server.get(&"/ping").await.assert_text(&"pong!"); 157 | } 158 | 159 | #[tokio::test] 160 | async fn it_should_create_and_test_with_make_into_service_with_state() { 161 | // Build an application with a route. 
162 | let app = Router::new() 163 | .route("/count", get(get_state)) 164 | .with_state(123) 165 | .into_make_service(); 166 | 167 | // Run the server. 168 | let server = TestServer::builder() 169 | .mock_transport() 170 | .build(app) 171 | .expect("Should create test server"); 172 | 173 | // Get the request. 174 | server.get(&"/count").await.assert_text(&"count is 123"); 175 | } 176 | 177 | #[tokio::test] 178 | async fn it_should_create_and_run_with_router_wrapped_service() { 179 | // Build an application with a route. 180 | let router = Router::new() 181 | .route("/count", get(get_state)) 182 | .with_state(123); 183 | let normalized_router = NormalizePathLayer::trim_trailing_slash().layer(router); 184 | let app = ServiceExt::::into_make_service(normalized_router); 185 | 186 | // Run the server. 187 | let server = TestServer::builder() 188 | .mock_transport() 189 | .build(app) 190 | .expect("Should create test server"); 191 | 192 | // Get the request. 193 | server.get(&"/count").await.assert_text(&"count is 123"); 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/into_make_service_with_connect_info.rs: -------------------------------------------------------------------------------- 1 | use crate::internals::HttpTransportLayer; 2 | use crate::transport_layer::IntoTransportLayer; 3 | use crate::transport_layer::TransportLayer; 4 | use crate::transport_layer::TransportLayerBuilder; 5 | use crate::util::spawn_serve; 6 | use anyhow::anyhow; 7 | use anyhow::Result; 8 | use axum::extract::connect_info::IntoMakeServiceWithConnectInfo; 9 | use axum::extract::Request as AxumRequest; 10 | use axum::response::Response as AxumResponse; 11 | use axum::serve::IncomingStream; 12 | use std::convert::Infallible; 13 | use tokio::net::TcpListener; 14 | use tower::Service; 15 | use url::Url; 16 | 17 | impl IntoTransportLayer for IntoMakeServiceWithConnectInfo 18 | where 19 | for<'a> C: 
axum::extract::connect_info::Connected>, 20 | S: Service + Clone + Send + 'static, 21 | S::Future: Send, 22 | { 23 | fn into_http_transport_layer( 24 | self, 25 | builder: TransportLayerBuilder, 26 | ) -> Result> { 27 | let (socket_addr, tcp_listener, maybe_reserved_port) = 28 | builder.tcp_listener_with_reserved_port()?; 29 | 30 | let serve_handle = spawn_serve(tcp_listener, self); 31 | let server_address = format!("http://{socket_addr}"); 32 | let server_url: Url = server_address.parse()?; 33 | 34 | Ok(Box::new(HttpTransportLayer::new( 35 | serve_handle, 36 | maybe_reserved_port, 37 | server_url, 38 | ))) 39 | } 40 | 41 | fn into_mock_transport_layer(self) -> Result> { 42 | Err(anyhow!("`IntoMakeServiceWithConnectInfo` cannot be mocked, as it's underlying implementation requires a real connection. Set the `TestServerConfig` to run with a transport of `HttpRandomPort`, or a `HttpIpPort`.")) 43 | } 44 | 45 | fn into_default_transport( 46 | self, 47 | builder: TransportLayerBuilder, 48 | ) -> Result> { 49 | self.into_http_transport_layer(builder) 50 | } 51 | } 52 | 53 | #[cfg(test)] 54 | mod test_into_http_transport_layer_for_into_make_service_with_connect_info { 55 | use crate::TestServer; 56 | use axum::extract::Request; 57 | use axum::routing::get; 58 | use axum::Router; 59 | use axum::ServiceExt; 60 | use std::net::SocketAddr; 61 | use tower::Layer; 62 | use tower_http::normalize_path::NormalizePathLayer; 63 | 64 | async fn get_ping() -> &'static str { 65 | "pong!" 66 | } 67 | 68 | #[tokio::test] 69 | async fn it_should_create_and_test_with_make_into_service_with_connect_info() { 70 | // Build an application with a route. 71 | let app = Router::new() 72 | .route("/ping", get(get_ping)) 73 | .into_make_service_with_connect_info::(); 74 | 75 | // Run the server. 76 | let server = TestServer::builder() 77 | .http_transport() 78 | .build(app) 79 | .expect("Should create test server"); 80 | 81 | // Get the request. 
82 | server.get(&"/ping").await.assert_text(&"pong!"); 83 | } 84 | 85 | #[tokio::test] 86 | async fn it_should_create_and_run_with_router_wrapped_service() { 87 | // Build an application with a route. 88 | let router = Router::new().route("/ping", get(get_ping)); 89 | let normalized_router = NormalizePathLayer::trim_trailing_slash().layer(router); 90 | let app = ServiceExt::::into_make_service_with_connect_info::( 91 | normalized_router, 92 | ); 93 | 94 | // Run the server. 95 | let server = TestServer::builder() 96 | .http_transport() 97 | .build(app) 98 | .expect("Should create test server"); 99 | 100 | // Get the request. 101 | server.get(&"/ping").await.assert_text(&"pong!"); 102 | } 103 | } 104 | 105 | #[cfg(test)] 106 | mod test_into_mock_transport_layer_for_into_make_service_with_connect_info { 107 | use crate::TestServer; 108 | use axum::routing::get; 109 | use axum::Router; 110 | use std::net::SocketAddr; 111 | 112 | async fn get_ping() -> &'static str { 113 | "pong!" 114 | } 115 | 116 | #[tokio::test] 117 | async fn it_should_panic_when_creating_test_using_mock() { 118 | // Build an application with a route. 119 | let app = Router::new() 120 | .route("/ping", get(get_ping)) 121 | .into_make_service_with_connect_info::(); 122 | 123 | // Build the server. 124 | let result = TestServer::builder().mock_transport().build(app); 125 | let err = result.unwrap_err(); 126 | let err_msg = format!("{}", err); 127 | 128 | assert_eq!(err_msg, "`IntoMakeServiceWithConnectInfo` cannot be mocked, as it's underlying implementation requires a real connection. 
Set the `TestServerConfig` to run with a transport of `HttpRandomPort`, or a `HttpIpPort`."); 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/router.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use axum::Router; 3 | 4 | use crate::transport_layer::IntoTransportLayer; 5 | use crate::transport_layer::TransportLayer; 6 | use crate::transport_layer::TransportLayerBuilder; 7 | 8 | impl IntoTransportLayer for Router<()> { 9 | fn into_http_transport_layer( 10 | self, 11 | builder: TransportLayerBuilder, 12 | ) -> Result> { 13 | self.into_make_service().into_http_transport_layer(builder) 14 | } 15 | 16 | fn into_mock_transport_layer(self) -> Result> { 17 | self.into_make_service().into_mock_transport_layer() 18 | } 19 | } 20 | 21 | #[cfg(test)] 22 | mod test_into_http_transport_layer { 23 | use axum::extract::State; 24 | use axum::routing::get; 25 | use axum::Router; 26 | 27 | use crate::TestServer; 28 | 29 | async fn get_ping() -> &'static str { 30 | "pong!" 31 | } 32 | 33 | async fn get_state(State(count): State) -> String { 34 | format!("count is {}", count) 35 | } 36 | 37 | #[tokio::test] 38 | async fn it_should_create_and_test_with_make_into_service() { 39 | // Build an application with a route. 40 | let app: Router = Router::new().route("/ping", get(get_ping)); 41 | 42 | // Run the server. 43 | let server = TestServer::builder() 44 | .http_transport() 45 | .build(app) 46 | .expect("Should create test server"); 47 | 48 | // Get the request. 49 | server.get(&"/ping").await.assert_text(&"pong!"); 50 | } 51 | 52 | #[tokio::test] 53 | async fn it_should_create_and_test_with_make_into_service_with_state() { 54 | // Build an application with a route. 55 | let app: Router = Router::new() 56 | .route("/count", get(get_state)) 57 | .with_state(123); 58 | 59 | // Run the server. 
60 | let server = TestServer::builder() 61 | .http_transport() 62 | .build(app) 63 | .expect("Should create test server"); 64 | 65 | // Get the request. 66 | server.get(&"/count").await.assert_text(&"count is 123"); 67 | } 68 | } 69 | 70 | #[cfg(test)] 71 | mod test_into_mock_transport_layer_for_router { 72 | use axum::extract::State; 73 | use axum::routing::get; 74 | use axum::Router; 75 | 76 | use crate::TestServer; 77 | 78 | async fn get_ping() -> &'static str { 79 | "pong!" 80 | } 81 | 82 | async fn get_state(State(count): State) -> String { 83 | format!("count is {}", count) 84 | } 85 | 86 | #[tokio::test] 87 | async fn it_should_create_and_test_with_make_into_service() { 88 | // Build an application with a route. 89 | let app: Router = Router::new().route("/ping", get(get_ping)); 90 | 91 | // Run the server. 92 | let server = TestServer::builder() 93 | .mock_transport() 94 | .build(app) 95 | .expect("Should create test server"); 96 | 97 | // Get the request. 98 | server.get(&"/ping").await.assert_text(&"pong!"); 99 | } 100 | 101 | #[tokio::test] 102 | async fn it_should_create_and_test_with_make_into_service_with_state() { 103 | // Build an application with a route. 104 | let app: Router = Router::new() 105 | .route("/count", get(get_state)) 106 | .with_state(123); 107 | 108 | // Run the server. 109 | let server = TestServer::builder() 110 | .mock_transport() 111 | .build(app) 112 | .expect("Should create test server"); 113 | 114 | // Get the request. 
115 | server.get(&"/count").await.assert_text(&"count is 123"); 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/serve.rs: -------------------------------------------------------------------------------- 1 | use crate::internals::HttpTransportLayer; 2 | use crate::transport_layer::IntoTransportLayer; 3 | use crate::transport_layer::TransportLayer; 4 | use crate::transport_layer::TransportLayerBuilder; 5 | use crate::util::ServeHandle; 6 | use anyhow::anyhow; 7 | use anyhow::Context; 8 | use anyhow::Result; 9 | use axum::extract::Request; 10 | use axum::response::Response; 11 | use axum::serve::IncomingStream; 12 | use axum::serve::Serve; 13 | use std::convert::Infallible; 14 | use tokio::net::TcpListener; 15 | use tokio::spawn; 16 | use tower::Service; 17 | use url::Url; 18 | 19 | impl IntoTransportLayer for Serve 20 | where 21 | M: for<'a> Service, Error = Infallible, Response = S> 22 | + Send 23 | + 'static, 24 | for<'a> >>::Future: Send, 25 | S: Service + Clone + Send + 'static, 26 | S::Future: Send, 27 | { 28 | fn into_http_transport_layer( 29 | self, 30 | _builder: TransportLayerBuilder, 31 | ) -> Result> { 32 | Err(anyhow!("`Serve` must be started with http or mock transport. Do not set any transport on `TestServerConfig`.")) 33 | } 34 | 35 | fn into_mock_transport_layer(self) -> Result> { 36 | Err(anyhow!("`Serve` cannot be mocked, as it's underlying implementation requires a real connection. 
Do not set any transport on `TestServerConfig`.")) 37 | } 38 | 39 | fn into_default_transport( 40 | self, 41 | _builder: TransportLayerBuilder, 42 | ) -> Result> { 43 | let socket_addr = self.local_addr()?; 44 | 45 | let join_handle = spawn(async move { 46 | self.await 47 | .context("Failed to create ::axum::Server for TestServer") 48 | .expect("Expect server to start serving"); 49 | }); 50 | 51 | let server_address = format!("http://{socket_addr}"); 52 | let server_url: Url = server_address.parse()?; 53 | 54 | Ok(Box::new(HttpTransportLayer::new( 55 | ServeHandle::new(join_handle), 56 | None, 57 | server_url, 58 | ))) 59 | } 60 | } 61 | 62 | #[cfg(test)] 63 | mod test_into_http_transport_layer { 64 | use crate::util::new_random_tokio_tcp_listener; 65 | use crate::TestServer; 66 | use axum::routing::get; 67 | use axum::routing::IntoMakeService; 68 | use axum::serve; 69 | use axum::Router; 70 | 71 | async fn get_ping() -> &'static str { 72 | "pong!" 73 | } 74 | 75 | #[tokio::test] 76 | #[should_panic] 77 | async fn it_should_panic_when_run_with_http() { 78 | // Build an application with a route. 79 | let app: IntoMakeService = Router::new() 80 | .route("/ping", get(get_ping)) 81 | .into_make_service(); 82 | let port = new_random_tokio_tcp_listener().unwrap(); 83 | let application = serve(port, app); 84 | 85 | // Run the server. 86 | TestServer::builder() 87 | .http_transport() 88 | .build(application) 89 | .expect("Should create test server"); 90 | } 91 | } 92 | 93 | #[cfg(test)] 94 | mod test_into_mock_transport_layer { 95 | use crate::util::new_random_tokio_tcp_listener; 96 | use crate::TestServer; 97 | use axum::routing::get; 98 | use axum::routing::IntoMakeService; 99 | use axum::serve; 100 | use axum::Router; 101 | 102 | async fn get_ping() -> &'static str { 103 | "pong!" 104 | } 105 | 106 | #[tokio::test] 107 | #[should_panic] 108 | async fn it_should_panic_when_run_with_mock_http() { 109 | // Build an application with a route. 
110 | let app: IntoMakeService = Router::new() 111 | .route("/ping", get(get_ping)) 112 | .into_make_service(); 113 | let port = new_random_tokio_tcp_listener().unwrap(); 114 | let application = serve(port, app); 115 | 116 | // Run the server. 117 | TestServer::builder() 118 | .mock_transport() 119 | .build(application) 120 | .expect("Should create test server"); 121 | } 122 | } 123 | 124 | #[cfg(test)] 125 | mod test_into_default_transport { 126 | use crate::util::new_random_tokio_tcp_listener; 127 | use crate::TestServer; 128 | use axum::routing::get; 129 | use axum::routing::IntoMakeService; 130 | use axum::serve; 131 | use axum::Router; 132 | 133 | async fn get_ping() -> &'static str { 134 | "pong!" 135 | } 136 | 137 | #[tokio::test] 138 | async fn it_should_run_service() { 139 | // Build an application with a route. 140 | let app: IntoMakeService = Router::new() 141 | .route("/ping", get(get_ping)) 142 | .into_make_service(); 143 | let port = new_random_tokio_tcp_listener().unwrap(); 144 | let application = serve(port, app); 145 | 146 | // Run the server. 147 | let server = TestServer::builder() 148 | .build(application) 149 | .expect("Should create test server"); 150 | 151 | // Get the request. 
152 | server.get(&"/ping").await.assert_text(&"pong!"); 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/shuttle_axum.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use shuttle_axum::ShuttleAxum; 3 | 4 | use crate::transport_layer::IntoTransportLayer; 5 | use crate::transport_layer::TransportLayer; 6 | use crate::transport_layer::TransportLayerBuilder; 7 | 8 | impl IntoTransportLayer for ShuttleAxum { 9 | fn into_http_transport_layer( 10 | self, 11 | builder: TransportLayerBuilder, 12 | ) -> Result> { 13 | self.map_err(Into::into) 14 | .and_then(|axum_service| axum_service.into_http_transport_layer(builder)) 15 | } 16 | 17 | fn into_mock_transport_layer(self) -> Result> { 18 | self.map_err(Into::into) 19 | .and_then(|axum_service| axum_service.into_mock_transport_layer()) 20 | } 21 | } 22 | 23 | #[cfg(test)] 24 | mod test_into_http_transport_layer_for_shuttle_axum { 25 | use super::*; 26 | 27 | use axum::extract::State; 28 | use axum::routing::get; 29 | use axum::Router; 30 | use shuttle_axum::AxumService; 31 | 32 | use crate::TestServer; 33 | 34 | async fn get_state(State(count): State) -> String { 35 | format!("count is {}", count) 36 | } 37 | 38 | #[tokio::test] 39 | async fn it_should_run() { 40 | // Build an application with a route. 41 | let router = Router::new() 42 | .route("/count", get(get_state)) 43 | .with_state(123); 44 | let app: ShuttleAxum = Ok(AxumService::from(router)); 45 | 46 | // Run the server. 47 | let server = TestServer::builder() 48 | .http_transport() 49 | .build(app) 50 | .expect("Should create test server"); 51 | 52 | // Get the request. 
53 | server.get(&"/count").await.assert_text(&"count is 123"); 54 | } 55 | } 56 | 57 | #[cfg(test)] 58 | mod test_into_mock_transport_layer_for_shuttle_axum { 59 | use super::*; 60 | 61 | use axum::extract::State; 62 | use axum::routing::get; 63 | use axum::Router; 64 | use shuttle_axum::AxumService; 65 | 66 | use crate::TestServer; 67 | 68 | async fn get_state(State(count): State) -> String { 69 | format!("count is {}", count) 70 | } 71 | 72 | #[tokio::test] 73 | async fn it_should_run() { 74 | // Build an application with a route. 75 | let router = Router::new() 76 | .route("/count", get(get_state)) 77 | .with_state(123); 78 | let app: ShuttleAxum = Ok(AxumService::from(router)); 79 | 80 | // Run the server. 81 | let server = TestServer::builder() 82 | .mock_transport() 83 | .build(app) 84 | .expect("Should create test server"); 85 | 86 | // Get the request. 87 | server.get(&"/count").await.assert_text(&"count is 123"); 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /src/transport_layer/into_transport_layer/with_graceful_shutdown.rs: -------------------------------------------------------------------------------- 1 | use crate::internals::HttpTransportLayer; 2 | use crate::transport_layer::IntoTransportLayer; 3 | use crate::transport_layer::TransportLayer; 4 | use crate::transport_layer::TransportLayerBuilder; 5 | use crate::util::ServeHandle; 6 | use anyhow::anyhow; 7 | use anyhow::Context; 8 | use anyhow::Result; 9 | use axum::extract::Request; 10 | use axum::response::Response; 11 | use axum::serve::IncomingStream; 12 | use axum::serve::WithGracefulShutdown; 13 | use std::convert::Infallible; 14 | use std::future::Future; 15 | use tokio::net::TcpListener; 16 | use tokio::spawn; 17 | use tower::Service; 18 | use url::Url; 19 | 20 | impl IntoTransportLayer for WithGracefulShutdown 21 | where 22 | M: for<'a> Service, Error = Infallible, Response = S> 23 | + Send 24 | + 'static, 25 | for<'a> >>::Future: Send, 26 
| S: Service + Clone + Send + 'static, 27 | S::Future: Send, 28 | F: Future + Send + 'static, 29 | { 30 | fn into_http_transport_layer( 31 | self, 32 | _builder: TransportLayerBuilder, 33 | ) -> Result> { 34 | Err(anyhow!("`WithGracefulShutdown` must be started with http or mock transport. Do not set any transport on `TestServerConfig`.")) 35 | } 36 | 37 | fn into_mock_transport_layer(self) -> Result> { 38 | Err(anyhow!("`WithGracefulShutdown` cannot be mocked, as it's underlying implementation requires a real connection. Do not set any transport on `TestServerConfig`.")) 39 | } 40 | 41 | fn into_default_transport( 42 | self, 43 | _builder: TransportLayerBuilder, 44 | ) -> Result> { 45 | let socket_addr = self.local_addr()?; 46 | 47 | let join_handle = spawn(async move { 48 | self.await 49 | .context("Failed to create ::axum::Server for TestServer") 50 | .expect("Expect server to start serving"); 51 | }); 52 | 53 | let server_address = format!("http://{socket_addr}"); 54 | let server_url: Url = server_address.parse()?; 55 | 56 | Ok(Box::new(HttpTransportLayer::new( 57 | ServeHandle::new(join_handle), 58 | None, 59 | server_url, 60 | ))) 61 | } 62 | } 63 | 64 | #[cfg(test)] 65 | mod test_into_http_transport_layer { 66 | use crate::util::new_random_tokio_tcp_listener; 67 | use crate::TestServer; 68 | use axum::routing::get; 69 | use axum::routing::IntoMakeService; 70 | use axum::serve; 71 | use axum::Router; 72 | use std::future::pending; 73 | 74 | async fn get_ping() -> &'static str { 75 | "pong!" 76 | } 77 | 78 | #[tokio::test] 79 | #[should_panic] 80 | async fn it_should_panic_when_run_with_http() { 81 | // Build an application with a route. 82 | let app: IntoMakeService = Router::new() 83 | .route("/ping", get(get_ping)) 84 | .into_make_service(); 85 | let port = new_random_tokio_tcp_listener().unwrap(); 86 | let application = serve(port, app).with_graceful_shutdown(pending()); 87 | 88 | // Run the server. 
89 | TestServer::builder() 90 | .http_transport() 91 | .build(application) 92 | .expect("Should create test server"); 93 | } 94 | } 95 | 96 | #[cfg(test)] 97 | mod test_into_mock_transport_layer { 98 | use crate::util::new_random_tokio_tcp_listener; 99 | use crate::TestServer; 100 | use axum::routing::get; 101 | use axum::routing::IntoMakeService; 102 | use axum::serve; 103 | use axum::Router; 104 | use std::future::pending; 105 | 106 | async fn get_ping() -> &'static str { 107 | "pong!" 108 | } 109 | 110 | #[tokio::test] 111 | #[should_panic] 112 | async fn it_should_panic_when_run_with_mock_http() { 113 | // Build an application with a route. 114 | let app: IntoMakeService = Router::new() 115 | .route("/ping", get(get_ping)) 116 | .into_make_service(); 117 | let port = new_random_tokio_tcp_listener().unwrap(); 118 | let application = serve(port, app).with_graceful_shutdown(pending()); 119 | 120 | // Run the server. 121 | TestServer::builder() 122 | .mock_transport() 123 | .build(application) 124 | .expect("Should create test server"); 125 | } 126 | } 127 | 128 | #[cfg(test)] 129 | mod test_into_default_transport { 130 | use crate::util::new_random_tokio_tcp_listener; 131 | use crate::TestServer; 132 | use axum::routing::get; 133 | use axum::routing::IntoMakeService; 134 | use axum::serve; 135 | use axum::Router; 136 | use std::future::pending; 137 | 138 | async fn get_ping() -> &'static str { 139 | "pong!" 140 | } 141 | 142 | #[tokio::test] 143 | async fn it_should_run_service() { 144 | // Build an application with a route. 145 | let app: IntoMakeService = Router::new() 146 | .route("/ping", get(get_ping)) 147 | .into_make_service(); 148 | let port = new_random_tokio_tcp_listener().unwrap(); 149 | let application = serve(port, app).with_graceful_shutdown(pending()); 150 | 151 | // Run the server. 152 | let server = TestServer::builder() 153 | .build(application) 154 | .expect("Should create test server"); 155 | 156 | // Get the request. 
157 | server.get(&"/ping").await.assert_text(&"pong!"); 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /src/transport_layer/mod.rs: -------------------------------------------------------------------------------- 1 | mod into_transport_layer; 2 | pub use self::into_transport_layer::*; 3 | 4 | mod transport_layer_builder; 5 | pub use self::transport_layer_builder::*; 6 | 7 | mod transport_layer_type; 8 | pub use self::transport_layer_type::*; 9 | 10 | mod transport_layer; 11 | pub use self::transport_layer::*; 12 | -------------------------------------------------------------------------------- /src/transport_layer/transport_layer.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use axum::body::Body; 3 | use http::Request; 4 | use http::Response; 5 | use std::fmt::Debug; 6 | use std::future::Future; 7 | use std::pin::Pin; 8 | use url::Url; 9 | 10 | use crate::transport_layer::TransportLayerType; 11 | 12 | pub trait TransportLayer: Debug + Send + Sync + 'static { 13 | fn send<'a>( 14 | &'a self, 15 | request: Request, 16 | ) -> Pin>>>>; 17 | 18 | fn url(&self) -> Option<&Url> { 19 | None 20 | } 21 | 22 | fn transport_layer_type(&self) -> TransportLayerType; 23 | 24 | fn is_running(&self) -> bool; 25 | } 26 | 27 | #[cfg(test)] 28 | mod test_sync { 29 | use super::*; 30 | use tokio::sync::OnceCell; 31 | 32 | #[test] 33 | fn it_should_compile_with_tokyo_once_cell() { 34 | // if it compiles, it works! 
35 | fn _take_tokio_once_cell(layer: T) -> OnceCell> 36 | where 37 | T: TransportLayer, 38 | { 39 | OnceCell::new_with(Some(Box::new(layer))) 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/transport_layer/transport_layer_builder.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Context; 2 | use anyhow::Result; 3 | use reserve_port::ReservedPort; 4 | use std::net::IpAddr; 5 | use std::net::SocketAddr; 6 | use tokio::net::TcpListener; 7 | 8 | use crate::internals::StartingTcpSetup; 9 | 10 | #[derive(Debug, Clone)] 11 | pub struct TransportLayerBuilder { 12 | ip: Option, 13 | port: Option, 14 | } 15 | 16 | impl TransportLayerBuilder { 17 | pub(crate) fn new(ip: Option, port: Option) -> Self { 18 | Self { ip, port } 19 | } 20 | 21 | pub(crate) fn tcp_listener_with_reserved_port( 22 | self, 23 | ) -> Result<(SocketAddr, TcpListener, Option)> { 24 | let setup = StartingTcpSetup::new(self.ip, self.port) 25 | .context("Cannot create socket address for use")?; 26 | 27 | let socket_addr = setup.socket_addr; 28 | let tcp_listener = setup.tcp_listener; 29 | let maybe_reserved_port = setup.maybe_reserved_port; 30 | 31 | Ok((socket_addr, tcp_listener, maybe_reserved_port)) 32 | } 33 | 34 | pub fn tcp_listener(self) -> Result { 35 | let (_, tcp_listener, _) = self.tcp_listener_with_reserved_port()?; 36 | Ok(tcp_listener) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/transport_layer/transport_layer_type.rs: -------------------------------------------------------------------------------- 1 | #[derive(Clone, Copy, PartialEq, Eq, Debug)] 2 | pub enum TransportLayerType { 3 | Http, 4 | Mock, 5 | } 6 | -------------------------------------------------------------------------------- /src/util/mod.rs: -------------------------------------------------------------------------------- 1 | mod new_random_port; 2 | pub 
use self::new_random_port::*; 3 | 4 | mod new_random_socket_addr; 5 | pub use self::new_random_socket_addr::*; 6 | 7 | mod new_random_tcp_listener; 8 | pub use self::new_random_tcp_listener::*; 9 | 10 | mod new_random_tokio_tcp_listener; 11 | pub use self::new_random_tokio_tcp_listener::*; 12 | 13 | mod spawn_serve; 14 | pub use self::spawn_serve::*; 15 | 16 | mod serve_handle; 17 | pub use self::serve_handle::*; 18 | -------------------------------------------------------------------------------- /src/util/new_random_port.rs: -------------------------------------------------------------------------------- 1 | use anyhow::anyhow; 2 | use anyhow::Result; 3 | use reserve_port::ReservedPort; 4 | 5 | /// Returns a randomly selected port that is not in use. 6 | pub fn new_random_port() -> Result { 7 | ReservedPort::random_permanently_reserved().map_err(|_| anyhow!("No free port was found")) 8 | } 9 | -------------------------------------------------------------------------------- /src/util/new_random_socket_addr.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use std::net::IpAddr; 3 | use std::net::Ipv4Addr; 4 | use std::net::SocketAddr; 5 | 6 | use crate::util::new_random_port; 7 | 8 | pub(crate) const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); 9 | 10 | /// Generates a `SocketAddr` on the IP 127.0.0.1, using a random port. 
11 | pub fn new_random_socket_addr() -> Result { 12 | let ip_address = DEFAULT_IP_ADDRESS; 13 | let port = new_random_port()?; 14 | let addr = SocketAddr::new(ip_address, port); 15 | 16 | Ok(addr) 17 | } 18 | -------------------------------------------------------------------------------- /src/util/new_random_tcp_listener.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use reserve_port::ReservedPort; 3 | use std::net::IpAddr; 4 | use std::net::Ipv4Addr; 5 | use std::net::SocketAddr; 6 | use std::net::TcpListener; 7 | 8 | pub(crate) const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); 9 | 10 | /// Binds a [`std::net::TcpListener`] on the IP 127.0.0.1, using a random port. 11 | /// 12 | /// This is the best way to pick a local port. 13 | pub fn new_random_tcp_listener() -> Result { 14 | let (tcp_listener, _) = ReservedPort::random_permanently_reserved_tcp(DEFAULT_IP_ADDRESS)?; 15 | Ok(tcp_listener) 16 | } 17 | 18 | /// Binds a [`std::net::TcpListener`] on the IP 127.0.0.1, using a random port. 19 | /// 20 | /// It is returned with the [`std::net::SocketAddr`] available. 21 | pub fn new_random_tcp_listener_with_socket_addr() -> Result<(TcpListener, SocketAddr)> { 22 | let result = ReservedPort::random_permanently_reserved_tcp(DEFAULT_IP_ADDRESS)?; 23 | Ok(result) 24 | } 25 | -------------------------------------------------------------------------------- /src/util/new_random_tokio_tcp_listener.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use reserve_port::ReservedPort; 3 | use std::net::IpAddr; 4 | use std::net::Ipv4Addr; 5 | use std::net::SocketAddr; 6 | use tokio::net::TcpListener as TokioTcpListener; 7 | 8 | pub(crate) const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); 9 | 10 | /// Binds a [`tokio::net::TcpListener`] on the IP 127.0.0.1, using a random port. 
11 | /// 12 | /// This is the best way to pick a local port. 13 | pub fn new_random_tokio_tcp_listener() -> Result { 14 | new_random_tokio_tcp_listener_with_socket_addr() 15 | .map(|(tokio_tcp_listener, _)| tokio_tcp_listener) 16 | } 17 | 18 | /// Binds a [`tokio::net::TcpListener`] on the IP 127.0.0.1, using a random port. 19 | /// 20 | /// It is returned with the [`std::net::SocketAddr`] available. 21 | pub fn new_random_tokio_tcp_listener_with_socket_addr() -> Result<(TokioTcpListener, SocketAddr)> { 22 | let (tcp_listener, random_socket) = 23 | ReservedPort::random_permanently_reserved_tcp(DEFAULT_IP_ADDRESS)?; 24 | 25 | tcp_listener.set_nonblocking(true)?; 26 | let tokio_tcp_listener = TokioTcpListener::from_std(tcp_listener)?; 27 | 28 | Ok((tokio_tcp_listener, random_socket)) 29 | } 30 | -------------------------------------------------------------------------------- /src/util/serve_handle.rs: -------------------------------------------------------------------------------- 1 | use tokio::task::JoinHandle; 2 | 3 | /// A handle to a running Axum service. 4 | /// 5 | /// When the handle is dropped, it will attempt to terminate the service. 
6 | #[derive(Debug)] 7 | pub struct ServeHandle { 8 | server_handle: JoinHandle<()>, 9 | } 10 | 11 | impl ServeHandle { 12 | pub(crate) fn new(server_handle: JoinHandle<()>) -> Self { 13 | Self { server_handle } 14 | } 15 | 16 | pub fn is_finished(&self) -> bool { 17 | self.server_handle.is_finished() 18 | } 19 | } 20 | 21 | impl Drop for ServeHandle { 22 | fn drop(&mut self) { 23 | self.server_handle.abort() 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/util/spawn_serve.rs: -------------------------------------------------------------------------------- 1 | use crate::util::ServeHandle; 2 | use axum::extract::Request; 3 | use axum::response::Response; 4 | use axum::serve; 5 | use axum::serve::IncomingStream; 6 | use axum::serve::Listener; 7 | use core::fmt::Debug; 8 | use std::convert::Infallible; 9 | use tokio::spawn; 10 | use tower::Service; 11 | 12 | /// A wrapper around [`axum::serve()`] for tests, 13 | /// which spawns the service in a new thread. 14 | /// 15 | /// The [`crate::util::ServeHandle`] returned will automatically attempt 16 | /// to terminate the service when dropped. 
17 | pub fn spawn_serve(tcp_listener: L, make_service: M) -> ServeHandle 18 | where 19 | L: Listener, 20 | L::Addr: Debug, 21 | M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, 22 | for<'a> >>::Future: Send, 23 | S: Service + Clone + Send + 'static, 24 | S::Future: Send, 25 | { 26 | let server_handle = spawn(async move { 27 | serve(tcp_listener, make_service) 28 | .await 29 | .expect("Expect server to start serving"); 30 | }); 31 | 32 | ServeHandle::new(server_handle) 33 | } 34 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | cargo +stable check 6 | cargo +stable test --example=example-shuttle --features shuttle 7 | cargo +stable test --example=example-todo 8 | cargo +stable test --example=example-websocket-ping-pong --features ws 9 | cargo +stable test --example=example-websocket-chat --features ws 10 | cargo +stable test --features all "$@" 11 | cargo +stable test "$@" 12 | 13 | # Check minimum version works, excluding shuttle 14 | cargo +1.83 check --features "pretty-assertions,yaml,msgpack,reqwest,typed-routing,ws" 15 | # Check nightly also works, see https://github.com/JosephLenton/axum-test/issues/133 16 | cargo +nightly check --features all "$@" 17 | 18 | # Check the various build variations work 19 | cargo +stable check --no-default-features 20 | cargo +stable check --features all 21 | cargo +stable check --features pretty-assertions 22 | cargo +stable check --features yaml 23 | cargo +stable check --features msgpack 24 | cargo +stable check --features reqwest 25 | cargo +stable check --features shuttle 26 | cargo +stable check --features typed-routing 27 | cargo +stable check --features ws 28 | cargo +stable check --features reqwest 29 | cargo +stable check --features old-json-diff 30 | 31 | cargo +stable clippy --features all 32 | 
-------------------------------------------------------------------------------- /tests/test-expect-json-integration.rs: -------------------------------------------------------------------------------- 1 | use axum_test::expect_json::expect_op; 2 | use axum_test::expect_json::ExpectOp; 3 | 4 | // Compile-only integration test: `#[expect_op]` and `ExpectOp` must be usable 5 | // through the public `axum_test::expect_json` re-export. If it compiles, it works, 6 | // so there is nothing to assert at runtime. 7 | #[test] 8 | fn test_expect_op_for_axum_test_integration_compiles() { 9 | #[expect_op] 10 | #[derive(Debug, Clone)] 11 | pub struct Testing; 12 | 13 | impl ExpectOp for Testing {} 14 | } 15 | --------------------------------------------------------------------------------