├── .github └── workflows │ ├── pr.yaml │ └── release.yaml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── examples └── example.rs ├── src ├── config.rs ├── error.rs ├── extension.rs ├── layer.rs ├── lib.rs ├── marker.rs ├── state.rs └── tx.rs └── tests └── lib.rs /.github/workflows/pr.yaml: -------------------------------------------------------------------------------- 1 | name: pr 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | clippy: 10 | name: clippy 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: actions-rs/toolchain@v1 15 | with: 16 | toolchain: stable 17 | override: true 18 | components: clippy 19 | - uses: actions-rs/cargo@v1 20 | with: 21 | command: clippy 22 | args: -- -D warnings 23 | 24 | doc: 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v2 28 | - uses: actions-rs/toolchain@v1 29 | with: 30 | toolchain: stable 31 | override: true 32 | - uses: actions-rs/cargo@v1 33 | with: 34 | command: doc 35 | 36 | fmt: 37 | runs-on: ubuntu-latest 38 | steps: 39 | - uses: actions/checkout@v2 40 | - uses: actions-rs/toolchain@v1 41 | with: 42 | toolchain: stable 43 | override: true 44 | components: rustfmt 45 | - uses: actions-rs/cargo@v1 46 | with: 47 | command: fmt 48 | args: -- --check 49 | 50 | test: 51 | runs-on: ubuntu-latest 52 | steps: 53 | - uses: actions/checkout@v2 54 | - uses: actions-rs/toolchain@v1 55 | with: 56 | toolchain: stable 57 | profile: default 58 | override: true 59 | - uses: actions-rs/cargo@v1 60 | with: 61 | command: test 62 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | push: 5 | tags: '[0-9]+.[0-9]+.[0-9]+' 6 | 7 | jobs: 8 | publish: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions-rs/toolchain@v1 13 | with: 14 | 
toolchain: stable 15 | override: true 16 | - uses: actions-rs/cargo@v1 17 | with: 18 | command: publish 19 | args: --no-verify 20 | env: 21 | CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.sublime-workspace 2 | /target 3 | Cargo.lock 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "axum-sqlx-tx" 3 | description = "Request-scoped SQLx transactions for axum" 4 | version = "0.10.0" 5 | license = "MIT" 6 | repository = "https://github.com/digital-society-coop/axum-sqlx-tx/" 7 | edition = "2021" 8 | include = [ 9 | "LICENSE", 10 | "README.md", 11 | "Cargo.toml", 12 | "**/*.rs" 13 | ] 14 | 15 | [dependencies] 16 | axum-core = "0.5" 17 | bytes = "1" 18 | futures-core = "0.3" 19 | http = "1" 20 | http-body = "1" 21 | parking_lot = { version = "0.12", features = ["arc_lock", "send_guard"] } 22 | sqlx = { version = "0.8", default-features = false } 23 | thiserror = "1" 24 | tower-layer = "0.3" 25 | tower-service = "0.3" 26 | 27 | [dev-dependencies] 28 | axum = "0.8.1" 29 | hyper = "1.0.1" 30 | sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] } 31 | tokio = { version = "1.17.0", features = ["macros", "rt-multi-thread"] } 32 | tower = "0.5.2" 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 axum-sqlx-tx authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the 
rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `axum-sqlx-tx` 2 | 3 | Request-bound [SQLx](https://github.com/launchbadge/sqlx) transactions for [axum](https://github.com/tokio-rs/axum). 4 | 5 | ## Summary 6 | 7 | `axum-sqlx-tx` provides an `axum` [extractor](https://docs.rs/axum/latest/axum/#extractors) for obtaining a request-bound transaction. 8 | The transaction begins the first time the extractor is used, and is stored with the request for use by other middleware/handlers. 9 | The transaction is resolved depending on the status code of the response – successful (`2XX` or `3XX`) responses will commit the transaction, otherwise it will be rolled back. 10 | 11 | See the [crate documentation](https://docs.rs/axum-sqlx-tx) for more information and examples. 12 | -------------------------------------------------------------------------------- /examples/example.rs: -------------------------------------------------------------------------------- 1 | //! 
A silly server that generates random numbers, but only commits positive ones. 2 | 3 | use std::error::Error; 4 | 5 | use axum::{response::IntoResponse, routing::get, Json}; 6 | use http::StatusCode; 7 | 8 | // Recommended: use a type alias to avoid repeating your database type 9 | type Tx = axum_sqlx_tx::Tx; 10 | 11 | #[tokio::main] 12 | async fn main() -> Result<(), Box> { 13 | // You can use any sqlx::Pool 14 | let pool = sqlx::SqlitePool::connect("sqlite::memory:").await?; 15 | 16 | // Create a table (in a real application you might run migrations) 17 | sqlx::query("CREATE TABLE IF NOT EXISTS numbers (number INT PRIMARY KEY);") 18 | .execute(&pool) 19 | .await?; 20 | 21 | let (state, layer) = Tx::setup(pool); 22 | 23 | // Standard axum app setup 24 | let app = axum::Router::new() 25 | .route("/numbers", get(list_numbers).post(generate_number)) 26 | // Apply the Tx middleware 27 | .layer(layer) 28 | // Add the Tx state 29 | .with_state(state); 30 | 31 | let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); 32 | println!("Listening on {}", listener.local_addr().unwrap()); 33 | 34 | axum::serve(listener, app).await?; 35 | 36 | Ok(()) 37 | } 38 | 39 | async fn list_numbers(mut tx: Tx) -> Result>, DbError> { 40 | let numbers: Vec<(i32,)> = sqlx::query_as("SELECT * FROM numbers") 41 | .fetch_all(&mut tx) 42 | .await?; 43 | 44 | Ok(Json(numbers.into_iter().map(|n| n.0).collect())) 45 | } 46 | 47 | async fn generate_number(mut tx: Tx) -> Result<(StatusCode, Json), DbError> { 48 | let (number,): (i32,) = 49 | sqlx::query_as("INSERT INTO numbers VALUES (random()) RETURNING number;") 50 | .fetch_one(&mut tx) 51 | .await?; 52 | 53 | // Simulate a possible error – in reality this could be something like interacting with another 54 | // service, or running another query. 55 | let status = if number > 0 { 56 | StatusCode::OK 57 | } else { 58 | StatusCode::IM_A_TEAPOT 59 | }; 60 | 61 | // no need to explicitly resolve! 
62 | Ok((status, Json(number))) 63 | } 64 | 65 | // An sqlx::Error wrapper that implements IntoResponse 66 | struct DbError(sqlx::Error); 67 | 68 | impl From for DbError { 69 | fn from(error: sqlx::Error) -> Self { 70 | Self(error) 71 | } 72 | } 73 | 74 | impl IntoResponse for DbError { 75 | fn into_response(self) -> axum::response::Response { 76 | println!("ERROR: {}", self.0); 77 | (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/config.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | 3 | use crate::{Layer, Marker, State}; 4 | 5 | /// Configuration for [`Tx`](crate::Tx) extractors. 6 | /// 7 | /// Use `Config` to configure and create a [`State`] and [`Layer`]. 8 | /// 9 | /// Access the `Config` API from [`Tx::config`](crate::Tx::config). 10 | /// 11 | /// ``` 12 | /// # async fn foo() { 13 | /// # let pool: sqlx::SqlitePool = todo!(); 14 | /// type Tx = axum_sqlx_tx::Tx; 15 | /// 16 | /// let config = Tx::config(pool); 17 | /// # } 18 | /// ``` 19 | pub struct Config { 20 | pool: sqlx::Pool, 21 | _layer_error: PhantomData, 22 | } 23 | 24 | impl Config 25 | where 26 | LayerError: axum_core::response::IntoResponse, 27 | sqlx::Error: Into, 28 | { 29 | pub(crate) fn new(pool: sqlx::Pool) -> Self { 30 | Self { 31 | pool, 32 | _layer_error: PhantomData, 33 | } 34 | } 35 | 36 | /// Change the layer error type. 37 | pub fn layer_error(self) -> Config 38 | where 39 | sqlx::Error: Into, 40 | { 41 | Config { 42 | pool: self.pool, 43 | _layer_error: PhantomData, 44 | } 45 | } 46 | 47 | /// Create a [`State`] and [`Layer`] to enable the [`Tx`](crate::Tx) extractor. 
48 | pub fn setup(self) -> (State, Layer) { 49 | let state = State::new(self.pool); 50 | let layer = Layer::new(state.clone()); 51 | (state, layer) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | /// Possible errors when extracting [`Tx`] from a request. 2 | /// 3 | /// Errors can occur at two points during the request lifecycle: 4 | /// 5 | /// 1. The [`Tx`] extractor might fail to obtain a connection from the pool and `BEGIN` a 6 | /// transaction. This could be due to: 7 | /// 8 | /// - Forgetting to add the middleware: [`Error::MissingExtension`]. 9 | /// - Calling the extractor multiple times in the same request: [`Error::OverlappingExtractors`]. 10 | /// - A problem communicating with the database: [`Error::Database`]. 11 | /// 12 | /// 2. The middleware [`Layer`] might fail to commit the transaction. This could be due to a problem 13 | /// communicating with the database, or else a logic error (e.g. unsatisfied deferred 14 | /// constraint): [`Error::Database`]. 15 | /// 16 | /// `axum` requires that errors can be turned into responses. The [`Error`] type converts into a 17 | /// HTTP 500 response with the error message as the response body. This may be suitable for 18 | /// development or internal services but it's generally not advisable to return internal error 19 | /// details to clients. 20 | /// 21 | /// You can override the error types for both the [`Tx`] extractor and [`Layer`]: 22 | /// 23 | /// - Override the [`Tx`]`` error type using the `E` generic type parameter. `E` must be 24 | /// convertible from [`Error`] (e.g. [`Error`]`: Into`). 25 | /// 26 | /// - Override the [`Layer`] error type using [`Config::layer_error`](crate::Config::layer_error). 27 | /// The layer error type must be convertible from `sqlx::Error` (e.g. 28 | /// `sqlx::Error: Into`). 
29 | /// 30 | /// In both cases, the error type must implement `axum::response::IntoResponse`. 31 | /// 32 | /// ``` 33 | /// use axum::{response::IntoResponse, routing::post}; 34 | /// 35 | /// enum MyError{ 36 | /// Extractor(axum_sqlx_tx::Error), 37 | /// Layer(sqlx::Error), 38 | /// } 39 | /// 40 | /// impl From for MyError { 41 | /// fn from(error: axum_sqlx_tx::Error) -> Self { 42 | /// Self::Extractor(error) 43 | /// } 44 | /// } 45 | /// 46 | /// impl From for MyError { 47 | /// fn from(error: sqlx::Error) -> Self { 48 | /// Self::Layer(error) 49 | /// } 50 | /// } 51 | /// 52 | /// impl IntoResponse for MyError { 53 | /// fn into_response(self) -> axum::response::Response { 54 | /// // note that you would probably want to log the error as well 55 | /// (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() 56 | /// } 57 | /// } 58 | /// 59 | /// // Override the `Tx` error type using the second generic type parameter 60 | /// type Tx = axum_sqlx_tx::Tx; 61 | /// 62 | /// # async fn foo() { 63 | /// let pool = sqlx::SqlitePool::connect("...").await.unwrap(); 64 | /// 65 | /// let (state, layer) = Tx::config(pool) 66 | /// // Override the `Layer` error type using the `Config` API 67 | /// .layer_error::() 68 | /// .setup(); 69 | /// # let app = axum::Router::new() 70 | /// # .route("/", post(create_user)) 71 | /// # .layer(layer) 72 | /// # .with_state(state); 73 | /// # let listener: tokio::net::TcpListener = todo!(); 74 | /// # axum::serve(listener, app); 75 | /// # } 76 | /// # async fn create_user(mut tx: Tx, /* ... */) { 77 | /// # /* ... */ 78 | /// # } 79 | /// ``` 80 | /// 81 | /// [`Tx`]: crate::Tx 82 | /// [`Layer`]: crate::Layer 83 | #[derive(Debug, thiserror::Error)] 84 | pub enum Error { 85 | /// Indicates that the [`Layer`](crate::Layer) middleware was not installed. 
86 | #[error("required extension not registered; did you add the axum_sqlx_tx::Layer middleware?")] 87 | MissingExtension, 88 | 89 | /// Indicates that [`Tx`](crate::Tx) was extracted multiple times in a single 90 | /// handler/middleware. 91 | #[error("axum_sqlx_tx::Tx extractor used multiple times in the same handler/middleware")] 92 | OverlappingExtractors, 93 | 94 | /// A database error occurred when starting or committing the transaction. 95 | #[error(transparent)] 96 | Database { 97 | #[from] 98 | error: sqlx::Error, 99 | }, 100 | } 101 | 102 | impl axum_core::response::IntoResponse for Error { 103 | fn into_response(self) -> axum_core::response::Response { 104 | (http::StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response() 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /src/extension.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use parking_lot::{lock_api::ArcMutexGuard, Mutex, RawMutex}; 4 | use sqlx::Transaction; 5 | 6 | use crate::{Error, Marker, State}; 7 | 8 | /// The request extension. 9 | pub(crate) struct Extension { 10 | slot: Arc>>, 11 | } 12 | 13 | impl Extension { 14 | pub(crate) fn new(state: State) -> Self { 15 | let slot = Arc::new(Mutex::new(LazyTransaction::new(state))); 16 | Self { slot } 17 | } 18 | 19 | pub(crate) async fn acquire( 20 | &self, 21 | ) -> Result>, Error> { 22 | let mut tx = self 23 | .slot 24 | .try_lock_arc() 25 | .ok_or(Error::OverlappingExtractors)?; 26 | tx.acquire().await?; 27 | 28 | Ok(tx) 29 | } 30 | 31 | pub(crate) async fn resolve(&self) -> Result<(), sqlx::Error> { 32 | if let Some(mut tx) = self.slot.try_lock_arc() { 33 | tx.resolve().await?; 34 | } 35 | Ok(()) 36 | } 37 | } 38 | 39 | impl Clone for Extension { 40 | fn clone(&self) -> Self { 41 | Self { 42 | slot: self.slot.clone(), 43 | } 44 | } 45 | } 46 | 47 | /// The lazy transaction. 
48 | pub(crate) struct LazyTransaction(LazyTransactionState); 49 | 50 | enum LazyTransactionState { 51 | Unacquired { 52 | state: State, 53 | }, 54 | Acquired { 55 | tx: Transaction<'static, DB::Driver>, 56 | }, 57 | Resolved, 58 | } 59 | 60 | impl LazyTransaction { 61 | fn new(state: State) -> Self { 62 | Self(LazyTransactionState::Unacquired { state }) 63 | } 64 | 65 | pub(crate) fn as_ref(&self) -> &Transaction<'static, DB::Driver> { 66 | match &self.0 { 67 | LazyTransactionState::Unacquired { .. } => { 68 | panic!("BUG: exposed unacquired LazyTransaction") 69 | } 70 | LazyTransactionState::Acquired { tx } => tx, 71 | LazyTransactionState::Resolved => panic!("BUG: exposed resolved LazyTransaction"), 72 | } 73 | } 74 | 75 | pub(crate) fn as_mut(&mut self) -> &mut Transaction<'static, DB::Driver> { 76 | match &mut self.0 { 77 | LazyTransactionState::Unacquired { .. } => { 78 | panic!("BUG: exposed unacquired LazyTransaction") 79 | } 80 | LazyTransactionState::Acquired { tx } => tx, 81 | LazyTransactionState::Resolved => panic!("BUG: exposed resolved LazyTransaction"), 82 | } 83 | } 84 | 85 | async fn acquire(&mut self) -> Result<(), Error> { 86 | match &self.0 { 87 | LazyTransactionState::Unacquired { state } => { 88 | let tx = state.transaction().await?; 89 | self.0 = LazyTransactionState::Acquired { tx }; 90 | Ok(()) 91 | } 92 | LazyTransactionState::Acquired { .. } => Ok(()), 93 | LazyTransactionState::Resolved => Err(Error::OverlappingExtractors), 94 | } 95 | } 96 | 97 | pub(crate) async fn resolve(&mut self) -> Result<(), sqlx::Error> { 98 | match std::mem::replace(&mut self.0, LazyTransactionState::Resolved) { 99 | LazyTransactionState::Unacquired { .. 
} | LazyTransactionState::Resolved => Ok(()), 100 | LazyTransactionState::Acquired { tx } => tx.commit().await, 101 | } 102 | } 103 | 104 | pub(crate) async fn commit(&mut self) -> Result<(), sqlx::Error> { 105 | match std::mem::replace(&mut self.0, LazyTransactionState::Resolved) { 106 | LazyTransactionState::Unacquired { .. } => { 107 | panic!("BUG: tried to commit unacquired transaction") 108 | } 109 | LazyTransactionState::Acquired { tx } => tx.commit().await, 110 | LazyTransactionState::Resolved => panic!("BUG: tried to commit resolved transaction"), 111 | } 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /src/layer.rs: -------------------------------------------------------------------------------- 1 | //! A [`tower_layer::Layer`] that enables the [`Tx`](crate::Tx) extractor. 2 | 3 | use std::marker::PhantomData; 4 | 5 | use axum_core::response::IntoResponse; 6 | use bytes::Bytes; 7 | use futures_core::future::BoxFuture; 8 | use http_body::Body; 9 | 10 | use crate::{extension::Extension, Marker, State}; 11 | 12 | /// A [`tower_layer::Layer`] that enables the [`Tx`] extractor. 13 | /// 14 | /// This layer adds a lazily-initialised transaction to the [request extensions]. The first time the 15 | /// [`Tx`] extractor is used on a request, a connection is acquired from the configured 16 | /// [`sqlx::Pool`] and a transaction is started on it. The same transaction will be returned for 17 | /// subsequent uses of [`Tx`] on the same request. The inner service is then called as normal. Once 18 | /// the inner service responds, the transaction is committed or rolled back depending on the status 19 | /// code of the response. 
20 | /// 21 | /// [`Tx`]: crate::Tx 22 | /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html 23 | pub struct Layer { 24 | state: State, 25 | _error: PhantomData, 26 | } 27 | 28 | impl Layer 29 | where 30 | E: IntoResponse, 31 | sqlx::Error: Into, 32 | { 33 | pub(crate) fn new(state: State) -> Self { 34 | Self { 35 | state, 36 | _error: PhantomData, 37 | } 38 | } 39 | } 40 | 41 | impl Clone for Layer { 42 | fn clone(&self) -> Self { 43 | Self { 44 | state: self.state.clone(), 45 | _error: self._error, 46 | } 47 | } 48 | } 49 | 50 | impl tower_layer::Layer for Layer 51 | where 52 | E: IntoResponse, 53 | sqlx::Error: Into, 54 | { 55 | type Service = Service; 56 | 57 | fn layer(&self, inner: S) -> Self::Service { 58 | Service { 59 | state: self.state.clone(), 60 | inner, 61 | _error: self._error, 62 | } 63 | } 64 | } 65 | 66 | /// A [`tower_service::Service`] that enables the [`Tx`](crate::Tx) extractor. 67 | /// 68 | /// See [`Layer`] for more information. 69 | pub struct Service { 70 | state: State, 71 | inner: S, 72 | _error: PhantomData, 73 | } 74 | 75 | // can't simply derive because `DB` isn't `Clone` 76 | impl Clone for Service { 77 | fn clone(&self) -> Self { 78 | Self { 79 | state: self.state.clone(), 80 | inner: self.inner.clone(), 81 | _error: self._error, 82 | } 83 | } 84 | } 85 | 86 | impl tower_service::Service> 87 | for Service 88 | where 89 | S: tower_service::Service< 90 | http::Request, 91 | Response = http::Response, 92 | Error = std::convert::Infallible, 93 | >, 94 | S::Future: Send + 'static, 95 | E: IntoResponse, 96 | sqlx::Error: Into, 97 | ResBody: Body + Send + 'static, 98 | ResBody::Error: Into>, 99 | { 100 | type Response = http::Response; 101 | type Error = S::Error; 102 | type Future = BoxFuture<'static, Result>; 103 | 104 | fn poll_ready( 105 | &mut self, 106 | cx: &mut std::task::Context<'_>, 107 | ) -> std::task::Poll> { 108 | self.inner.poll_ready(cx).map_err(|err| match err {}) 109 | } 110 | 111 | fn 
call(&mut self, mut req: http::Request) -> Self::Future { 112 | let ext = Extension::new(self.state.clone()); 113 | req.extensions_mut().insert(ext.clone()); 114 | 115 | let res = self.inner.call(req); 116 | 117 | Box::pin(async move { 118 | let res = res.await.unwrap(); // inner service is infallible 119 | 120 | if !res.status().is_server_error() && !res.status().is_client_error() { 121 | if let Err(error) = ext.resolve().await { 122 | return Ok(error.into().into_response()); 123 | } 124 | } 125 | 126 | Ok(res.map(axum_core::body::Body::new)) 127 | }) 128 | } 129 | } 130 | 131 | #[cfg(test)] 132 | mod tests { 133 | use tokio::net::TcpListener; 134 | 135 | use crate::{Error, State}; 136 | 137 | use super::Layer; 138 | 139 | // The trait shenanigans required by axum for layers are significant, so this "test" ensures 140 | // we've got it right. 141 | #[allow(unused, unreachable_code, clippy::diverging_sub_expression)] 142 | fn layer_compiles() { 143 | let state: State = todo!(); 144 | 145 | let layer = Layer::<_, Error>::new(state); 146 | 147 | let app = axum::Router::new() 148 | .route("/", axum::routing::get(|| async { "hello" })) 149 | .layer(layer); 150 | 151 | let listener: TcpListener = todo!(); 152 | axum::serve(listener, app); 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Request-bound [SQLx] transactions for [axum]. 2 | //! 3 | //! [SQLx]: https://github.com/launchbadge/sqlx#readme 4 | //! [axum]: https://github.com/tokio-rs/axum#readme 5 | //! 6 | //! [`Tx`] is an `axum` [extractor][axum extractors] for obtaining a transaction that's bound to the 7 | //! HTTP request. A transaction begins the first time the extractor is used for a request, and is 8 | //! then stored in [request extensions] for use by other middleware/handlers. The transaction is 9 | //! 
resolved depending on the status code of the eventual response – successful (HTTP `2XX` or 10 | //! `3XX`) responses will cause the transaction to be committed, otherwise it will be rolled back. 11 | //! 12 | //! This behaviour is often a sensible default, and using the extractor (e.g. rather than directly 13 | //! using [`sqlx::Transaction`]s) means you can't forget to commit the transactions! 14 | //! 15 | //! [axum extractors]: https://docs.rs/axum/latest/axum/#extractors 16 | //! [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html 17 | //! 18 | //! # Usage 19 | //! 20 | //! To use the [`Tx`] extractor, you must first add [`State`] and [`Layer`] to your app. [`State`] 21 | //! holds the configuration for the extractor, and the [`Layer`] middleware manages the 22 | //! request-bound transaction. 23 | //! 24 | //! ``` 25 | //! # async fn foo() { 26 | //! // It's recommended to create aliases specialised for your extractor(s) 27 | //! type Tx = axum_sqlx_tx::Tx; 28 | //! 29 | //! let pool = sqlx::SqlitePool::connect("...").await.unwrap(); 30 | //! 31 | //! let (state, layer) = Tx::setup(pool); 32 | //! 33 | //! let app = axum::Router::new() 34 | //! // .route(...)s 35 | //! # .route("/", axum::routing::get(|tx: Tx| async move {})) 36 | //! .layer(layer) 37 | //! .with_state(state); 38 | //! # let listener: tokio::net::TcpListener = todo!(); 39 | //! # axum::serve(listener, app); 40 | //! # } 41 | //! ``` 42 | //! 43 | //! You can then simply add [`Tx`] as an argument to your handlers: 44 | //! 45 | //! ``` 46 | //! type Tx = axum_sqlx_tx::Tx; 47 | //! 48 | //! async fn create_user(mut tx: Tx, /* ... */) { 49 | //! // `&mut Tx` implements `sqlx::Executor` 50 | //! let user = sqlx::query("INSERT INTO users (...) VALUES (...)") 51 | //! .fetch_one(&mut tx) 52 | //! .await 53 | //! .unwrap(); 54 | //! 55 | //! // `Tx` also implements `Deref` and `DerefMut` 56 | //! use sqlx::Acquire; 57 | //! let inner = tx.begin().await.unwrap(); 58 | //! 
/* ... */ 59 | //! } 60 | //! ``` 61 | //! 62 | //! ## Error handling 63 | //! 64 | //! `axum` requires that errors can be turned into responses. The [`Error`] type converts into a 65 | //! HTTP 500 response with the error message as the response body. This may be suitable for 66 | //! development or internal services but it's generally not advisable to return internal error 67 | //! details to clients. 68 | //! 69 | //! See [`Error`] for how to customise error handling. 70 | //! 71 | //! ## Multiple databases 72 | //! 73 | //! If you need to work with multiple databases, you can define marker structs for each. See 74 | //! [`Marker`] for an example. 75 | //! 76 | //! It's not currently possible to use `Tx` for a dynamic number of databases. Feel free to open an 77 | //! issue if you have a requirement for this. 78 | //! 79 | //! ## Accessing the pool 80 | //! 81 | //! Note that [`State`] implements [`FromRef`](axum_core::extract::FromRef) into the inner SQLx pool. Therefore, 82 | //! if you still need to access the database pool at some handler, you can use axum's `State` 83 | //! extractor normally. 84 | //! 85 | //! ``` 86 | //! use axum::extract::State; 87 | //! 88 | //! async fn this_still_works(State(pool): State) { 89 | //! /* ... */ 90 | //! } 91 | //! ``` 92 | //! 93 | //! # Examples 94 | //! 95 | //! See [`examples/`][examples] in the repo for more examples. 96 | //! 97 | //! 
[examples]: https://github.com/digital-society-coop/axum-sqlx-tx/tree/master/examples 98 | 99 | #![cfg_attr(doc, deny(warnings))] 100 | 101 | mod config; 102 | mod error; 103 | mod extension; 104 | mod layer; 105 | mod marker; 106 | mod state; 107 | mod tx; 108 | 109 | pub use crate::{ 110 | config::Config, 111 | error::Error, 112 | layer::{Layer, Service}, 113 | marker::Marker, 114 | state::State, 115 | tx::Tx, 116 | }; 117 | -------------------------------------------------------------------------------- /src/marker.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | 3 | /// Extractor marker type. 4 | /// 5 | /// Since the [`Tx`](crate::Tx) extractor operates at the type level, a generic type parameter is 6 | /// used to identify different databases. 7 | /// 8 | /// There is a blanket implementation for all implementors of [`sqlx::Database`], but you can create 9 | /// your own types if you need to work with multiple databases. 
10 | /// 11 | /// ``` 12 | /// // Marker struct "database 1" 13 | /// #[derive(Debug)] 14 | /// struct Db1; 15 | /// 16 | /// impl axum_sqlx_tx::Marker for Db1 { 17 | /// type Driver = sqlx::Sqlite; 18 | /// } 19 | /// 20 | /// // Marker struct "database 2" 21 | /// #[derive(Debug)] 22 | /// struct Db2; 23 | /// 24 | /// impl axum_sqlx_tx::Marker for Db2 { 25 | /// type Driver = sqlx::Sqlite; 26 | /// } 27 | /// 28 | /// // You'll also need a "state" structure that implements `FromRef` for each `State` 29 | /// #[derive(Clone)] 30 | /// struct MyState { 31 | /// state1: axum_sqlx_tx::State, 32 | /// state2: axum_sqlx_tx::State, 33 | /// } 34 | /// 35 | /// impl axum::extract::FromRef for axum_sqlx_tx::State { 36 | /// fn from_ref(state: &MyState) -> Self { 37 | /// state.state1.clone() 38 | /// } 39 | /// } 40 | /// 41 | /// impl axum::extract::FromRef for axum_sqlx_tx::State { 42 | /// fn from_ref(state: &MyState) -> Self { 43 | /// state.state2.clone() 44 | /// } 45 | /// } 46 | /// 47 | /// // The extractor can then be aliased for each DB 48 | /// type Tx1 = axum_sqlx_tx::Tx; 49 | /// type Tx2 = axum_sqlx_tx::Tx; 50 | /// 51 | /// # async fn foo() { 52 | /// // Setup each extractor 53 | /// let pool1 = sqlx::SqlitePool::connect("...").await.unwrap(); 54 | /// let (state1, layer1) = Tx1::setup(pool1); 55 | /// 56 | /// let pool2 = sqlx::SqlitePool::connect("...").await.unwrap(); 57 | /// let (state2, layer2) = Tx2::setup(pool2); 58 | /// 59 | /// let app = axum::Router::new() 60 | /// .route("/", axum::routing::get(|tx1: Tx1, tx2: Tx2| async move { 61 | /// /* ... */ 62 | /// })) 63 | /// .layer(layer1) 64 | /// .layer(layer2) 65 | /// .with_state(MyState { state1, state2 }); 66 | /// # let listener: tokio::net::TcpListener = todo!(); 67 | /// # axum::serve(listener, app); 68 | /// # } 69 | /// ``` 70 | pub trait Marker: Debug + Send + Sized + 'static { 71 | /// The `sqlx` database driver. 
72 | type Driver: sqlx::Database; 73 | } 74 | 75 | impl Marker for DB { 76 | type Driver = Self; 77 | } 78 | -------------------------------------------------------------------------------- /src/state.rs: -------------------------------------------------------------------------------- 1 | use axum_core::extract::FromRef; 2 | 3 | use crate::Marker; 4 | 5 | /// Application state that enables the [`Tx`] extractor. 6 | /// 7 | /// `State` must be provided to `Router`s in order to use the [`Tx`] extractor, or else attempting 8 | /// to use the `Router` will not compile. 9 | /// 10 | /// `State` is constructed via [`Tx::setup`](crate::Tx::setup) or 11 | /// [`Config::setup`](crate::Config::setup), which also return a middleware [`Layer`](crate::Layer). 12 | /// The state and the middleware together enable the [`Tx`] extractor to work. 13 | /// 14 | /// [`Tx`]: crate::Tx 15 | #[derive(Debug)] 16 | pub struct State { 17 | pool: sqlx::Pool, 18 | } 19 | 20 | impl State { 21 | pub(crate) fn new(pool: sqlx::Pool) -> Self { 22 | Self { pool } 23 | } 24 | 25 | pub(crate) async fn transaction( 26 | &self, 27 | ) -> Result, sqlx::Error> { 28 | self.pool.begin().await 29 | } 30 | } 31 | 32 | impl Clone for State { 33 | fn clone(&self) -> Self { 34 | Self { 35 | pool: self.pool.clone(), 36 | } 37 | } 38 | } 39 | 40 | impl FromRef> for sqlx::Pool { 41 | fn from_ref(input: &State) -> Self { 42 | input.pool.clone() 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/tx.rs: -------------------------------------------------------------------------------- 1 | //! A request extension that enables the [`Tx`](crate::Tx) extractor. 
2 | 3 | use std::{fmt, marker::PhantomData}; 4 | 5 | use axum_core::{ 6 | extract::{FromRef, FromRequestParts}, 7 | response::IntoResponse, 8 | }; 9 | use futures_core::{future::BoxFuture, stream::BoxStream}; 10 | use http::request::Parts; 11 | use parking_lot::{lock_api::ArcMutexGuard, RawMutex}; 12 | 13 | use crate::{ 14 | extension::{Extension, LazyTransaction}, 15 | Config, Error, Marker, State, 16 | }; 17 | 18 | /// An `axum` extractor for a database transaction. 19 | /// 20 | /// `&mut Tx` implements [`sqlx::Executor`] so it can be used directly with [`sqlx::query()`] 21 | /// (and [`sqlx::query_as()`], the corresponding macros, etc.): 22 | /// 23 | /// ``` 24 | /// use axum_sqlx_tx::Tx; 25 | /// use sqlx::Sqlite; 26 | /// 27 | /// async fn handler(mut tx: Tx) -> Result<(), sqlx::Error> { 28 | /// sqlx::query("...").execute(&mut tx).await?; 29 | /// /* ... */ 30 | /// # Ok(()) 31 | /// } 32 | /// ``` 33 | /// 34 | /// It also implements `Deref` and `DerefMut`, so you can call 35 | /// methods from `Transaction` and its traits: 36 | /// 37 | /// ``` 38 | /// use axum_sqlx_tx::Tx; 39 | /// use sqlx::{Acquire as _, Sqlite}; 40 | /// 41 | /// async fn handler(mut tx: Tx) -> Result<(), sqlx::Error> { 42 | /// let inner = tx.begin().await?; 43 | /// /* ... */ 44 | /// # Ok(()) 45 | /// } 46 | /// ``` 47 | /// 48 | /// The `E` generic parameter controls the error type returned when the extractor fails. 
This can be 49 | /// used to configure the error response returned when the extractor fails: 50 | /// 51 | /// ``` 52 | /// use axum::response::IntoResponse; 53 | /// use axum_sqlx_tx::Tx; 54 | /// use sqlx::Sqlite; 55 | /// 56 | /// struct MyError(axum_sqlx_tx::Error); 57 | /// 58 | /// // The error type must implement From 59 | /// impl From for MyError { 60 | /// fn from(error: axum_sqlx_tx::Error) -> Self { 61 | /// Self(error) 62 | /// } 63 | /// } 64 | /// 65 | /// // The error type must implement IntoResponse 66 | /// impl IntoResponse for MyError { 67 | /// fn into_response(self) -> axum::response::Response { 68 | /// (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() 69 | /// } 70 | /// } 71 | /// 72 | /// async fn handler(tx: Tx) { 73 | /// /* ... */ 74 | /// } 75 | /// ``` 76 | pub struct Tx { 77 | tx: ArcMutexGuard>, 78 | _error: PhantomData, 79 | } 80 | 81 | impl Tx { 82 | /// Create a [`State`] and [`Layer`](crate::Layer) to enable the extractor. 83 | /// 84 | /// This is convenient to use from a type alias, e.g. 85 | /// 86 | /// ``` 87 | /// # async fn foo() { 88 | /// type Tx = axum_sqlx_tx::Tx; 89 | /// 90 | /// let pool: sqlx::SqlitePool = todo!(); 91 | /// let (state, layer) = Tx::setup(pool); 92 | /// # } 93 | /// ``` 94 | pub fn setup(pool: sqlx::Pool) -> (State, crate::Layer) { 95 | Config::new(pool).setup() 96 | } 97 | 98 | /// Configure extractor behaviour. 99 | /// 100 | /// See the [`Config`] API for available options. 101 | /// 102 | /// This is convenient to use from a type alias, e.g. 103 | /// 104 | /// ``` 105 | /// # async fn foo() { 106 | /// type Tx = axum_sqlx_tx::Tx; 107 | /// 108 | /// # let pool: sqlx::SqlitePool = todo!(); 109 | /// let config = Tx::config(pool); 110 | /// # } 111 | /// ``` 112 | pub fn config(pool: sqlx::Pool) -> Config { 113 | Config::new(pool) 114 | } 115 | 116 | /// Explicitly commit the transaction. 
117 | /// 118 | /// By default, the transaction will be committed when a successful response is returned 119 | /// (specifically, when the [`Service`](crate::Service) middleware intercepts an HTTP `2XX` or 120 | /// `3XX` response). This method allows the transaction to be committed explicitly. 121 | /// 122 | /// **Note:** trying to use the `Tx` extractor again after calling `commit` will currently 123 | /// generate [`Error::OverlappingExtractors`] errors. This may change in future. 124 | pub async fn commit(mut self) -> Result<(), sqlx::Error> { 125 | self.tx.commit().await 126 | } 127 | } 128 | 129 | impl fmt::Debug for Tx { 130 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 131 | f.debug_struct("Tx").finish_non_exhaustive() 132 | } 133 | } 134 | 135 | impl AsRef> for Tx { 136 | fn as_ref(&self) -> &sqlx::Transaction<'static, DB::Driver> { 137 | self.tx.as_ref() 138 | } 139 | } 140 | 141 | impl AsMut> for Tx { 142 | fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB::Driver> { 143 | self.tx.as_mut() 144 | } 145 | } 146 | 147 | impl std::ops::Deref for Tx { 148 | type Target = sqlx::Transaction<'static, DB::Driver>; 149 | 150 | fn deref(&self) -> &Self::Target { 151 | self.tx.as_ref() 152 | } 153 | } 154 | 155 | impl std::ops::DerefMut for Tx { 156 | fn deref_mut(&mut self) -> &mut Self::Target { 157 | self.tx.as_mut() 158 | } 159 | } 160 | 161 | impl FromRequestParts for Tx 162 | where 163 | S: Sync, 164 | E: From + IntoResponse + Send, 165 | State: FromRef, 166 | { 167 | type Rejection = E; 168 | 169 | async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { 170 | let ext: &Extension = parts.extensions.get().ok_or(Error::MissingExtension)?; 171 | 172 | let tx = ext.acquire().await?; 173 | 174 | Ok(Self { 175 | tx, 176 | _error: PhantomData, 177 | }) 178 | } 179 | } 180 | 181 | impl<'c, DB, E> sqlx::Executor<'c> for &'c mut Tx 182 | where 183 | DB: Marker, 184 | for<'t> &'t mut ::Connection: 185 | sqlx::Executor<'t, Database 
= DB::Driver>, 186 | E: std::fmt::Debug + Send, 187 | { 188 | type Database = DB::Driver; 189 | 190 | #[allow(clippy::type_complexity)] 191 | fn fetch_many<'e, 'q: 'e, Q>( 192 | self, 193 | query: Q, 194 | ) -> BoxStream< 195 | 'e, 196 | Result< 197 | sqlx::Either< 198 | ::QueryResult, 199 | ::Row, 200 | >, 201 | sqlx::Error, 202 | >, 203 | > 204 | where 205 | 'c: 'e, 206 | Q: sqlx::Execute<'q, Self::Database> + 'q, 207 | { 208 | (&mut ***self).fetch_many(query) 209 | } 210 | 211 | fn fetch_optional<'e, 'q: 'e, Q>( 212 | self, 213 | query: Q, 214 | ) -> BoxFuture<'e, Result::Row>, sqlx::Error>> 215 | where 216 | 'c: 'e, 217 | Q: sqlx::Execute<'q, Self::Database> + 'q, 218 | { 219 | (&mut ***self).fetch_optional(query) 220 | } 221 | 222 | fn prepare_with<'e, 'q: 'e>( 223 | self, 224 | sql: &'q str, 225 | parameters: &'e [::TypeInfo], 226 | ) -> BoxFuture<'e, Result<::Statement<'q>, sqlx::Error>> 227 | where 228 | 'c: 'e, 229 | { 230 | (&mut ***self).prepare_with(sql, parameters) 231 | } 232 | 233 | fn describe<'e, 'q: 'e>( 234 | self, 235 | sql: &'q str, 236 | ) -> BoxFuture<'e, Result, sqlx::Error>> 237 | where 238 | 'c: 'e, 239 | { 240 | (&mut ***self).describe(sql) 241 | } 242 | } 243 | -------------------------------------------------------------------------------- /tests/lib.rs: -------------------------------------------------------------------------------- 1 | use axum::{middleware, response::IntoResponse}; 2 | use axum_sqlx_tx::State; 3 | use sqlx::{sqlite::SqliteArguments, Arguments as _}; 4 | use tower::ServiceExt; 5 | 6 | type Tx = axum_sqlx_tx::Tx; 7 | 8 | #[tokio::test] 9 | async fn commit_on_success() { 10 | let (pool, response) = build_app(|mut tx: Tx| async move { 11 | let (_, name) = insert_user(&mut tx, 1, "huge hackerman").await; 12 | format!("hello {name}") 13 | }) 14 | .await; 15 | 16 | assert!(response.status.is_success()); 17 | assert_eq!(response.body, "hello huge hackerman"); 18 | 19 | let users: Vec<(i32, String)> = sqlx::query_as("SELECT * 
FROM users") 20 | .fetch_all(&pool) 21 | .await 22 | .unwrap(); 23 | assert_eq!(users, vec![(1, "huge hackerman".to_string())]); 24 | } 25 | 26 | #[tokio::test] 27 | async fn commit_on_redirection() { 28 | let (pool, response) = build_app(|mut tx: Tx| async move { 29 | let (_, _) = insert_user(&mut tx, 1, "john redirect").await; 30 | http::StatusCode::SEE_OTHER 31 | }) 32 | .await; 33 | 34 | assert!(response.status.is_redirection()); 35 | 36 | let users: Vec<(i32, String)> = sqlx::query_as("SELECT * FROM users") 37 | .fetch_all(&pool) 38 | .await 39 | .unwrap(); 40 | assert_eq!(users, vec![(1, "john redirect".to_string())]); 41 | } 42 | 43 | #[tokio::test] 44 | async fn rollback_on_error() { 45 | let (pool, response) = build_app(|mut tx: Tx| async move { 46 | insert_user(&mut tx, 1, "michael oxmaul").await; 47 | http::StatusCode::BAD_REQUEST 48 | }) 49 | .await; 50 | 51 | assert!(response.status.is_client_error()); 52 | assert!(response.body.is_empty()); 53 | 54 | assert_eq!(get_users(&pool).await, vec![]); 55 | } 56 | 57 | #[tokio::test] 58 | async fn explicit_commit() { 59 | let (pool, response) = build_app(|mut tx: Tx| async move { 60 | insert_user(&mut tx, 1, "michael oxmaul").await; 61 | tx.commit().await.unwrap(); 62 | http::StatusCode::BAD_REQUEST 63 | }) 64 | .await; 65 | 66 | assert!(response.status.is_client_error()); 67 | assert!(response.body.is_empty()); 68 | 69 | assert_eq!( 70 | get_users(&pool).await, 71 | vec![(1, "michael oxmaul".to_string())] 72 | ); 73 | } 74 | 75 | #[tokio::test] 76 | async fn extract_from_middleware_and_handler() { 77 | let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); 78 | 79 | sqlx::query("CREATE TABLE IF NOT EXISTS users (id INT PRIMARY KEY, name TEXT);") 80 | .execute(&pool) 81 | .await 82 | .unwrap(); 83 | 84 | async fn test_middleware( 85 | mut tx: Tx, 86 | req: http::Request, 87 | next: middleware::Next, 88 | ) -> impl IntoResponse { 89 | insert_user(&mut tx, 1, "bobby tables").await; 90 | 91 | // 
If we explicitly drop `tx` it should be consumable from the next handler. 92 | drop(tx); 93 | next.run(req).await 94 | } 95 | 96 | let (state, layer) = Tx::setup(pool); 97 | 98 | let app = axum::Router::new() 99 | .route( 100 | "/", 101 | axum::routing::get(|mut tx: Tx| async move { 102 | let users: Vec<(i32, String)> = sqlx::query_as("SELECT * FROM users") 103 | .fetch_all(&mut tx) 104 | .await 105 | .unwrap(); 106 | axum::Json(users) 107 | }), 108 | ) 109 | .layer(middleware::from_fn_with_state( 110 | state.clone(), 111 | test_middleware, 112 | )) 113 | .layer(layer) 114 | .with_state(state); 115 | 116 | let response = app 117 | .oneshot( 118 | http::Request::builder() 119 | .uri("/") 120 | .body(axum::body::Body::empty()) 121 | .unwrap(), 122 | ) 123 | .await 124 | .unwrap(); 125 | let status = response.status(); 126 | let body = axum::body::to_bytes(response.into_body(), usize::MAX) 127 | .await 128 | .unwrap(); 129 | 130 | assert!(status.is_success()); 131 | assert_eq!(body.as_ref(), b"[[1,\"bobby tables\"]]"); 132 | } 133 | 134 | #[tokio::test] 135 | async fn middleware_cloning_request_extensions() { 136 | let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); 137 | 138 | async fn test_middleware( 139 | req: http::Request, 140 | next: middleware::Next, 141 | ) -> impl IntoResponse { 142 | // Hold a clone of the request extensions 143 | let _extensions = req.extensions().clone(); 144 | 145 | next.run(req).await 146 | } 147 | 148 | let (state, layer) = Tx::setup(pool); 149 | 150 | let app = axum::Router::new() 151 | .route("/", axum::routing::get(|_tx: Tx| async move {})) 152 | .layer(middleware::from_fn_with_state( 153 | state.clone(), 154 | test_middleware, 155 | )) 156 | .layer(layer) 157 | .with_state(state); 158 | 159 | let response = app 160 | .oneshot( 161 | http::Request::builder() 162 | .uri("/") 163 | .body(axum::body::Body::empty()) 164 | .unwrap(), 165 | ) 166 | .await 167 | .unwrap(); 168 | let status = response.status(); 169 | let 
body = axum::body::to_bytes(response.into_body(), usize::MAX) 170 | .await 171 | .unwrap(); 172 | dbg!(body); 173 | 174 | assert!(status.is_success()); 175 | } 176 | 177 | #[tokio::test] 178 | async fn substates() { 179 | #[derive(Clone)] 180 | struct MyState { 181 | state: State, 182 | } 183 | 184 | impl axum_core::extract::FromRef for State { 185 | fn from_ref(state: &MyState) -> Self { 186 | state.state.clone() 187 | } 188 | } 189 | 190 | let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); 191 | 192 | let (state, layer) = Tx::setup(pool); 193 | 194 | let app = axum::Router::new() 195 | .route("/", axum::routing::get(|_: Tx| async move {})) 196 | .layer(layer) 197 | .with_state(MyState { state }); 198 | let response = app 199 | .oneshot( 200 | http::Request::builder() 201 | .uri("/") 202 | .body(axum::body::Body::empty()) 203 | .unwrap(), 204 | ) 205 | .await 206 | .unwrap(); 207 | 208 | assert!(response.status().is_success()); 209 | } 210 | 211 | #[tokio::test] 212 | async fn extract_pool_from_state() { 213 | let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); 214 | 215 | let (state, layer) = Tx::setup(pool); 216 | 217 | let app = axum::Router::new() 218 | .route( 219 | "/", 220 | axum::routing::get( 221 | |axum::extract::State(_pool): axum::extract::State| async move {}, 222 | ), 223 | ) 224 | .layer(layer) 225 | .with_state(state); 226 | 227 | let response = app 228 | .oneshot( 229 | http::Request::builder() 230 | .uri("/") 231 | .body(axum::body::Body::empty()) 232 | .unwrap(), 233 | ) 234 | .await 235 | .unwrap(); 236 | 237 | assert!(response.status().is_success()); 238 | } 239 | 240 | #[tokio::test] 241 | async fn missing_layer() { 242 | let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); 243 | 244 | // Note that we have to explicitly ignore the `_layer`, making it hard to do this accidentally. 
245 | let (state, _layer) = Tx::setup(pool); 246 | 247 | let app = axum::Router::new() 248 | .route("/", axum::routing::get(|_: Tx| async move {})) 249 | .with_state(state); 250 | let response = app 251 | .oneshot( 252 | http::Request::builder() 253 | .uri("/") 254 | .body(axum::body::Body::empty()) 255 | .unwrap(), 256 | ) 257 | .await 258 | .unwrap(); 259 | 260 | assert!(response.status().is_server_error()); 261 | 262 | let body = axum::body::to_bytes(response.into_body(), usize::MAX) 263 | .await 264 | .unwrap(); 265 | assert_eq!(body, format!("{}", axum_sqlx_tx::Error::MissingExtension)); 266 | } 267 | 268 | #[tokio::test] 269 | async fn overlapping_extractors() { 270 | let (_, response) = build_app(|_: Tx, _: Tx| async move {}).await; 271 | 272 | assert!(response.status.is_server_error()); 273 | assert_eq!( 274 | response.body, 275 | format!("{}", axum_sqlx_tx::Error::OverlappingExtractors) 276 | ); 277 | } 278 | 279 | #[tokio::test] 280 | async fn extractor_error_override() { 281 | let (_, response) = 282 | build_app(|_: Tx, _: axum_sqlx_tx::Tx| async move {}).await; 283 | 284 | assert!(response.status.is_client_error()); 285 | assert_eq!(response.body, "internal server error"); 286 | } 287 | 288 | #[tokio::test] 289 | async fn layer_error_override() { 290 | let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); 291 | 292 | sqlx::query("CREATE TABLE IF NOT EXISTS users (id INT PRIMARY KEY);") 293 | .execute(&pool) 294 | .await 295 | .unwrap(); 296 | sqlx::query( 297 | r#" 298 | CREATE TABLE IF NOT EXISTS comments ( 299 | id INT PRIMARY KEY, 300 | user_id INT, 301 | FOREIGN KEY (user_id) REFERENCES users(id) DEFERRABLE INITIALLY DEFERRED 302 | );"#, 303 | ) 304 | .execute(&pool) 305 | .await 306 | .unwrap(); 307 | 308 | let (state, layer) = Tx::config(pool).layer_error::().setup(); 309 | 310 | let app = axum::Router::new() 311 | .route( 312 | "/", 313 | axum::routing::get(|mut tx: Tx| async move { 314 | sqlx::query("INSERT INTO comments VALUES 
(random(), random())") 315 | .execute(&mut tx) 316 | .await 317 | .unwrap(); 318 | }), 319 | ) 320 | .layer(layer) 321 | .with_state(state); 322 | 323 | let response = app 324 | .oneshot( 325 | http::Request::builder() 326 | .uri("/") 327 | .body(axum::body::Body::empty()) 328 | .unwrap(), 329 | ) 330 | .await 331 | .unwrap(); 332 | let status = response.status(); 333 | let body = axum::body::to_bytes(response.into_body(), usize::MAX) 334 | .await 335 | .unwrap(); 336 | 337 | assert!(status.is_client_error()); 338 | assert_eq!(body, "internal server error"); 339 | } 340 | 341 | #[tokio::test] 342 | async fn multi_db() { 343 | #[derive(Debug)] 344 | struct DbA; 345 | impl axum_sqlx_tx::Marker for DbA { 346 | type Driver = sqlx::Sqlite; 347 | } 348 | type TxA = axum_sqlx_tx::Tx; 349 | 350 | #[derive(Debug)] 351 | struct DbB; 352 | impl axum_sqlx_tx::Marker for DbB { 353 | type Driver = sqlx::Sqlite; 354 | } 355 | type TxB = axum_sqlx_tx::Tx; 356 | 357 | let pool_a = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); 358 | let pool_b = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); 359 | 360 | sqlx::query("CREATE TABLE IF NOT EXISTS users (id INT PRIMARY KEY);") 361 | .execute(&pool_a) 362 | .await 363 | .unwrap(); 364 | sqlx::query( 365 | r#" 366 | CREATE TABLE IF NOT EXISTS comments ( 367 | id INT PRIMARY KEY, 368 | user_id INT 369 | );"#, 370 | ) 371 | .execute(&pool_b) 372 | .await 373 | .unwrap(); 374 | 375 | let (state_a, layer_a) = TxA::setup(pool_a); 376 | let (state_b, layer_b) = TxB::setup(pool_b); 377 | 378 | #[derive(Clone)] 379 | struct State { 380 | state_a: axum_sqlx_tx::State, 381 | state_b: axum_sqlx_tx::State, 382 | } 383 | 384 | impl axum::extract::FromRef for axum_sqlx_tx::State { 385 | fn from_ref(input: &State) -> Self { 386 | input.state_a.clone() 387 | } 388 | } 389 | 390 | impl axum::extract::FromRef for axum_sqlx_tx::State { 391 | fn from_ref(input: &State) -> Self { 392 | input.state_b.clone() 393 | } 394 | } 395 | 
396 | let app = axum::Router::new() 397 | .route( 398 | "/", 399 | axum::routing::get(|mut tx_a: TxA, mut tx_b: TxB| async move { 400 | sqlx::query("SELECT * FROM users") 401 | .execute(&mut tx_a) 402 | .await 403 | .unwrap(); 404 | sqlx::query("SELECT * FROM comments") 405 | .execute(&mut tx_b) 406 | .await 407 | .unwrap(); 408 | }), 409 | ) 410 | .layer(layer_a) 411 | .layer(layer_b) 412 | .with_state(State { state_a, state_b }); 413 | 414 | let response = app 415 | .oneshot( 416 | http::Request::builder() 417 | .uri("/") 418 | .body(axum::body::Body::empty()) 419 | .unwrap(), 420 | ) 421 | .await 422 | .unwrap(); 423 | let status = response.status(); 424 | 425 | assert!(status.is_success()); 426 | } 427 | 428 | async fn insert_user(tx: &mut Tx, id: i32, name: &str) -> (i32, String) { 429 | let mut args = SqliteArguments::default(); 430 | args.add(id).unwrap(); 431 | args.add(name).unwrap(); 432 | sqlx::query_as_with( 433 | r#"INSERT INTO users VALUES (?, ?) RETURNING id, name;"#, 434 | args, 435 | ) 436 | .fetch_one(tx) 437 | .await 438 | .unwrap() 439 | } 440 | 441 | async fn get_users(pool: &sqlx::SqlitePool) -> Vec<(i32, String)> { 442 | sqlx::query_as("SELECT * FROM users") 443 | .fetch_all(pool) 444 | .await 445 | .unwrap() 446 | } 447 | 448 | struct Response { 449 | status: http::StatusCode, 450 | body: axum::body::Bytes, 451 | } 452 | 453 | async fn build_app(handler: H) -> (sqlx::SqlitePool, Response) 454 | where 455 | H: axum::handler::Handler>, 456 | T: 'static, 457 | { 458 | let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); 459 | 460 | sqlx::query("CREATE TABLE IF NOT EXISTS users (id INT PRIMARY KEY, name TEXT);") 461 | .execute(&pool) 462 | .await 463 | .unwrap(); 464 | 465 | let (state, layer) = Tx::setup(pool.clone()); 466 | 467 | let app = axum::Router::new() 468 | .route("/", axum::routing::get(handler)) 469 | .layer(layer) 470 | .with_state(state); 471 | 472 | let response = app 473 | .oneshot( 474 | 
http::Request::builder() 475 | .uri("/") 476 | .body(axum::body::Body::empty()) 477 | .unwrap(), 478 | ) 479 | .await 480 | .unwrap(); 481 | let status = response.status(); 482 | let body = axum::body::to_bytes(response.into_body(), usize::MAX) 483 | .await 484 | .unwrap(); 485 | 486 | (pool, Response { status, body }) 487 | } 488 | 489 | struct MyExtractorError { 490 | _0: axum_sqlx_tx::Error, 491 | } 492 | 493 | impl From for MyExtractorError { 494 | fn from(error: axum_sqlx_tx::Error) -> Self { 495 | Self { _0: error } 496 | } 497 | } 498 | 499 | impl IntoResponse for MyExtractorError { 500 | fn into_response(self) -> axum::response::Response { 501 | (http::StatusCode::IM_A_TEAPOT, "internal server error").into_response() 502 | } 503 | } 504 | 505 | struct MyLayerError { 506 | _0: sqlx::Error, 507 | } 508 | 509 | impl From for MyLayerError { 510 | fn from(error: sqlx::Error) -> Self { 511 | Self { _0: error } 512 | } 513 | } 514 | 515 | impl IntoResponse for MyLayerError { 516 | fn into_response(self) -> axum::response::Response { 517 | (http::StatusCode::IM_A_TEAPOT, "internal server error").into_response() 518 | } 519 | } 520 | --------------------------------------------------------------------------------