├── .envrc ├── .github └── workflows │ └── rust.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── examples └── customize_connection.rs ├── renovate.json ├── src ├── async_traits.rs ├── connection.rs ├── connection_manager.rs ├── error.rs └── lib.rs └── tests ├── README.md └── test.rs /.envrc: -------------------------------------------------------------------------------- 1 | # For use with direnv: https://direnv.net 2 | # See also: ./env.sh 3 | 4 | PATH_add out/cockroachdb/bin 5 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | # 2 | # Configuration for GitHub-based CI, based on the stock GitHub Rust config. 3 | # 4 | name: Rust 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | check-style: 14 | runs-on: ubuntu-latest 15 | steps: 16 | # actions/checkout@v2 17 | - uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 18 | - uses: actions-rs/toolchain@88dc2356392166efad76775c878094f4e83ff746 19 | with: 20 | toolchain: stable 21 | default: false 22 | components: rustfmt 23 | - name: Check style 24 | run: cargo fmt -- --check 25 | 26 | check-without-cockroach: 27 | runs-on: ubuntu-latest 28 | steps: 29 | # actions/checkout@v2 30 | - uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 31 | - name: Cargo check 32 | run: cargo check --no-default-features 33 | 34 | build-and-test: 35 | runs-on: ${{ matrix.os }} 36 | strategy: 37 | matrix: 38 | os: [ ubuntu-latest, macos-12 ] 39 | steps: 40 | # actions/checkout@v2 41 | - uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 42 | - name: Build 43 | run: cargo build --tests --verbose 44 | - name: Run tests 45 | run: cargo test --verbose 46 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | out 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "async-bb8-diesel" 3 | description = "async bb8 connection manager for Diesel" 4 | version = "0.2.1" 5 | authors = ["Sean Klein "] 6 | edition = "2018" 7 | license = "MIT" 8 | repository = "https://github.com/oxidecomputer/async-bb8-diesel" 9 | keywords = ["diesel", "r2d2", "pool", "tokio", "async"] 10 | 11 | [features] 12 | # Enables CockroachDB-specific functions. 13 | cockroach = [] 14 | default = [ "cockroach" ] 15 | 16 | [dependencies] 17 | bb8 = "0.8" 18 | async-trait = "0.1.81" 19 | diesel = { version = "2.2.2", default-features = false, features = [ "r2d2" ] } 20 | futures = "0.3" 21 | thiserror = "1.0" 22 | tokio = { version = "1.32", default-features = false, features = [ "rt-multi-thread" ] } 23 | 24 | [dev-dependencies] 25 | anyhow = "1.0" 26 | crdb-harness = "0.0.1" 27 | diesel = { version = "2.2.2", features = [ "postgres", "r2d2" ] } 28 | libc = "0.2.154" 29 | tempfile = "3.8" 30 | tokio = { version = "1.32", features = [ "macros", "fs", "process" ] } 31 | tokio-postgres = { version = "0.7", features = [ "with-chrono-0_4", "with-uuid-1" ] } 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Oxide Computer Company 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom 
the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # async-bb8-diesel 2 | 3 | This crate provides an interface for asynchronously accessing a bb8 connection 4 | pool atop [Diesel](https://github.com/diesel-rs/diesel). 5 | 6 | This is accomplished by implementing an async version 7 | of Diesel's "RunQueryDsl" trait, aptly named "AsyncRunQueryDsl", 8 | which operates on an async-compatible connection. When called 9 | from an async context, these operations transfer the query 10 | to a blocking tokio thread, where it may be executed. 11 | 12 | NOTE: This crate pre-dated [diesel-async](https://docs.rs/diesel-async/). 13 | For new code, consider using that interface directly. 14 | 15 | # Pre-requisites 16 | 17 | - A willingness to tolerate some instability. This crate effectively originated 18 | as a stop-gap until more native asynchronous support existed within Diesel. 
19 | 20 | # Comparisons with existing crates 21 | 22 | This crate was heavily inspired by both 23 | [tokio-diesel](https://github.com/mehcode/tokio-diesel) and 24 | [bb8-diesel](https://github.com/overdrivenpotato/bb8-diesel), but serves a 25 | slightly different purpose. 26 | 27 | ## What do those crates do? 28 | 29 | Both of those crates rely heavily on the 30 | [`tokio::task::block_in_place`](https://docs.rs/tokio/1.10.1/tokio/task/fn.block_in_place.html) 31 | function to actually execute synchronous Diesel queries. 32 | 33 | Their flow is effectively: 34 | - A query is issued (in the case of tokio-diesel, it's async. In the case 35 | of bb8-diesel, it's synchronous - but you're using bb8, so presumably 36 | calling from an asynchronous task). 37 | - The query and connection to the DB are moved into the `block_in_place` call. 38 | - Diesel's native synchronous API is used within `block_in_place`. 39 | 40 | These crates have some advantages by taking this approach: 41 | - The tokio executor knows not to schedule additional async tasks for the 42 | duration of the `block_in_place` call. 43 | - The callback used within `block_in_place` doesn't need to be `Send` - it 44 | executes synchronously within the otherwise asynchronous task. 45 | 46 | However, they also have some downsides: 47 | - The call to `block_in_place` effectively pauses an async thread for the 48 | duration of the call. This *requires* a multi-threaded runtime, and reduces 49 | the efficacy of one of these threads for the duration of the call. 50 | - The call to `block_in_place` starves all other asynchronous code running in 51 | the same task. 52 | 53 | This starvation results in some subtle inhibition of other futures, such as in 54 | the following example, where a timeout would be ignored if a long-running 55 | database operation was issued from the same task as the timeout. 56 | 57 | ```rust 58 | tokio::select!
{ 59 | // Calls "tokio::block_in_place", doing a synchronous Diesel operation 60 | // on the calling thread... 61 | _ = perform_database_operation() => {}, 62 | // ... meaning this asynchronous timeout cannot complete! 63 | _ = sleep_until(timeout) => {}, 64 | } 65 | ``` 66 | 67 | ## What does this crate do? 68 | 69 | This crate attempts to avoid calls to `block_in_place` - which would block the 70 | calling thread - and prefers to use the 71 | [`tokio::task::spawn_blocking`](https://docs.rs/tokio/1.10.1/tokio/task/fn.spawn_blocking.html) 72 | function. This function moves the requested operation to an entirely distinct 73 | thread where blocking is acceptable, but does *not* prevent the current task 74 | from executing other asynchronous work. 75 | 76 | This isn't entirely free - as this work now needs to be transferred to a new 77 | thread, it imposes a "Send + 'static" constraint on the queries which are 78 | constructed. 79 | 80 | ## Which one is right for me? 81 | 82 | - If you care about preserving typically expected semantics for asynchronous 83 | operations, we recommend this crate. 84 | - If you don't - maybe you have an asynchronous workload, but you *know* 85 | when you can block those threads - you can use either of the 86 | tokio-diesel or bb8-diesel crates, depending on whether or not you 87 | want access to the asynchronous thread pool. 88 | -------------------------------------------------------------------------------- /examples/customize_connection.rs: -------------------------------------------------------------------------------- 1 | //! An example showing how to customize connections while using pooling.
2 | 3 | use async_bb8_diesel::{AsyncSimpleConnection, Connection, ConnectionError}; 4 | use async_trait::async_trait; 5 | use diesel::pg::PgConnection; 6 | /// Connection customizer that runs a raw SQL batch on connections from its `on_acquire` hook. 7 | #[derive(Debug)] 8 | struct ConnectionCustomizer {} 9 | 10 | type DieselPgConn = Connection; 11 | // NOTE: the SQL string below is a placeholder; substitute real per-connection setup SQL. 12 | #[async_trait] 13 | impl bb8::CustomizeConnection for ConnectionCustomizer { 14 | async fn on_acquire(&self, connection: &mut DieselPgConn) -> Result<(), ConnectionError> { 15 | connection 16 | .batch_execute_async("please execute some raw sql for me") 17 | .await 18 | .map_err(ConnectionError::from) 19 | } 20 | } 21 | /// Builds a bb8 pool with the customizer installed; "localhost:1234" is a placeholder database URL. 22 | #[tokio::main] 23 | async fn main() { 24 | let manager = async_bb8_diesel::ConnectionManager::::new("localhost:1234"); 25 | let _ = bb8::Pool::builder() 26 | .connection_customizer(Box::new(ConnectionCustomizer {})) 27 | .build(manager) 28 | .await 29 | .unwrap(); 30 | } 31 | -------------------------------------------------------------------------------- /renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json", 3 | "extends": [ 4 | "local>oxidecomputer/renovate-config" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /src/async_traits.rs: -------------------------------------------------------------------------------- 1 | //! Async versions of traits for issuing Diesel queries.
2 | 3 | use crate::connection::Connection; 4 | use async_trait::async_trait; 5 | use diesel::{ 6 | connection::{ 7 | Connection as DieselConnection, SimpleConnection, TransactionManager, 8 | TransactionManagerStatus, 9 | }, 10 | dsl::Limit, 11 | query_dsl::{ 12 | methods::{ExecuteDsl, LimitDsl, LoadQuery}, 13 | RunQueryDsl, 14 | }, 15 | r2d2::R2D2Connection, 16 | result::Error as DieselError, 17 | }; 18 | use futures::future::BoxFuture; 19 | use futures::future::FutureExt; 20 | use std::any::Any; 21 | use std::future::Future; 22 | use std::sync::Arc; 23 | use std::sync::MutexGuard; 24 | use tokio::task::spawn_blocking; 25 | 26 | /// An async variant of [`diesel::connection::SimpleConnection`]. 27 | #[async_trait] 28 | pub trait AsyncSimpleConnection 29 | where 30 | Conn: 'static + SimpleConnection, 31 | { 32 | async fn batch_execute_async(&self, query: &str) -> Result<(), DieselError>; 33 | } 34 | /// Returns whether `err` is a serialization failure, the only error kind treated as retryable. 35 | #[cfg(feature = "cockroach")] 36 | fn retryable_error(err: &DieselError) -> bool { 37 | use diesel::result::DatabaseErrorKind::SerializationFailure; 38 | match err { 39 | DieselError::DatabaseError(SerializationFailure, _boxed_error_information) => true, 40 | _ => false, 41 | } 42 | } 43 | 44 | /// An async variant of [`diesel::r2d2::R2D2Connection`]. 45 | #[async_trait] 46 | pub trait AsyncR2D2Connection: AsyncConnection 47 | where 48 | Conn: 'static + DieselConnection + R2D2Connection, 49 | Self: Send + Sized + 'static, 50 | { 51 | async fn ping_async(&mut self) -> diesel::result::QueryResult<()> { 52 | self.as_async_conn().run(|conn| conn.ping()).await 53 | } 54 | /// Reports whether the underlying connection considers itself broken, checked on a blocking thread. 55 | async fn is_broken_async(&mut self) -> bool { 56 | self.as_async_conn() 57 | .run(|conn| Ok::(conn.is_broken())) 58 | .await 59 | .unwrap() 60 | } 61 | } 62 | 63 | /// An async variant of [`diesel::connection::Connection`].
64 | #[async_trait] 65 | pub trait AsyncConnection: AsyncSimpleConnection 66 | where 67 | Conn: 'static + DieselConnection, 68 | Self: Send + Sized + 'static, 69 | { 70 | #[doc(hidden)] 71 | fn get_owned_connection(&self) -> Self; 72 | #[doc(hidden)] 73 | fn as_sync_conn(&self) -> MutexGuard<'_, Conn>; 74 | #[doc(hidden)] 75 | fn as_async_conn(&self) -> &Connection; 76 | 77 | /// Runs the function `f` on a blocking thread (via `spawn_blocking`), where blocking is safe. 78 | async fn run(&self, f: Func) -> Result 79 | where 80 | R: Send + 'static, 81 | E: Send + 'static, 82 | Func: FnOnce(&mut Conn) -> Result + Send + 'static, 83 | { 84 | let connection = self.get_owned_connection(); 85 | connection.run_with_connection(f).await 86 | } 87 | 88 | #[doc(hidden)] 89 | async fn run_with_connection(self, f: Func) -> Result 90 | where 91 | R: Send + 'static, 92 | E: Send + 'static, 93 | Func: FnOnce(&mut Conn) -> Result + Send + 'static, 94 | { 95 | spawn_blocking(move || f(&mut *self.as_sync_conn())) 96 | .await 97 | .unwrap() // Propagate panics 98 | } 99 | 100 | #[doc(hidden)] 101 | async fn run_with_shared_connection(self: &Arc, f: Func) -> Result 102 | where 103 | R: Send + 'static, 104 | E: Send + 'static, 105 | Func: FnOnce(&mut Conn) -> Result + Send + 'static, 106 | { 107 | let conn = self.clone(); 108 | spawn_blocking(move || f(&mut *conn.as_sync_conn())) 109 | .await 110 | .unwrap() // Propagate panics 111 | } 112 | 113 | #[doc(hidden)] 114 | async fn transaction_depth(&self) -> Result { 115 | let conn = self.get_owned_connection(); 116 | 117 | Self::run_with_connection(conn, |conn| { 118 | match Conn::TransactionManager::transaction_manager_status_mut(&mut *conn) { 119 | TransactionManagerStatus::Valid(status) => { 120 | Ok(status.transaction_depth().map(|d| d.into()).unwrap_or(0)) 121 | } 122 | TransactionManagerStatus::InError => Err(DieselError::BrokenTransactionManager), 123 | } 124 | }) 125 | .await 126 | } 127 | 128 | // Diesel's "begin_transaction" chooses whether to issue "BEGIN" or
a 129 | // "SAVEPOINT" depending on the transaction depth. 130 | // 131 | // This method is a wrapper around that call, with validation that 132 | // we're actually issuing the BEGIN statement here. 133 | #[doc(hidden)] 134 | async fn start_transaction(self: &Arc) -> Result<(), DieselError> { 135 | if self.transaction_depth().await? != 0 { 136 | return Err(DieselError::AlreadyInTransaction); 137 | } 138 | self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) 139 | .await?; 140 | Ok(()) 141 | } 142 | 143 | // Diesel's "begin_transaction" chooses whether to issue "BEGIN" or a 144 | // "SAVEPOINT" depending on the transaction depth. 145 | // 146 | // This method is a wrapper around that call, with validation that 147 | // we're actually issuing our first SAVEPOINT here. 148 | #[doc(hidden)] 149 | async fn add_retry_savepoint(self: &Arc) -> Result<(), DieselError> { 150 | match self.transaction_depth().await? { 151 | 0 => return Err(DieselError::NotInTransaction), 152 | 1 => (), 153 | _ => return Err(DieselError::AlreadyInTransaction), 154 | }; 155 | 156 | self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) 157 | .await?; 158 | Ok(()) 159 | } 160 | 161 | #[doc(hidden)] 162 | async fn commit_transaction(self: &Arc) -> Result<(), DieselError> { 163 | self.run_with_shared_connection(|conn| Conn::TransactionManager::commit_transaction(conn)) 164 | .await?; 165 | Ok(()) 166 | } 167 | 168 | #[doc(hidden)] 169 | async fn rollback_transaction(self: &Arc) -> Result<(), DieselError> { 170 | self.run_with_shared_connection(|conn| { 171 | Conn::TransactionManager::rollback_transaction(conn) 172 | }) 173 | .await?; 174 | Ok(()) 175 | } 176 | 177 | /// Issues a function `f` as a transaction. 178 | /// 179 | /// If `f`, or committing its work, fails with a retryable error (a serialization failure), asynchronously calls `retry` to decide whether to try again. 180 | /// 181 | /// This function returns an error if it is called from within an existing 182 | /// transaction.
183 | #[cfg(feature = "cockroach")] 184 | async fn transaction_async_with_retry( 185 | &'a self, 186 | f: Func, 187 | retry: RetryFunc, 188 | ) -> Result 189 | where 190 | R: Any + Send + 'static, 191 | Fut: FutureExt> + Send, 192 | Func: (Fn(Connection) -> Fut) + Send + Sync, 193 | RetryFut: FutureExt + Send, 194 | RetryFunc: Fn() -> RetryFut + Send + Sync, 195 | { 196 | // This function sure has a bunch of generic parameters, which can cause 197 | // a lot of code to be generated, and can slow down compile-time. 198 | // 199 | // This API intends to provide a convenient, generic, shim over the 200 | // dynamic "transaction_async_with_retry_inner" function below, which 201 | // should avoid being generic. The goal here is to instantiate only one 202 | // significant "body" of this function, while retaining flexibility for 203 | // clients using this library. 204 | 205 | // Box the functions, and box the return value. 206 | let f = |conn| { 207 | f(conn) 208 | .map(|result| result.map(|r| Box::new(r) as Box)) 209 | .boxed() 210 | }; 211 | let retry = || retry().boxed(); 212 | 213 | // Call the dynamically dispatched function, then retrieve the return 214 | // value out of a Box. 215 | self.transaction_async_with_retry_inner(&f, &retry) 216 | .await 217 | .map(|v| *v.downcast::().expect("Should be an 'R' type")) 218 | } 219 | 220 | // NOTE: This function intentionally avoids all generics! 221 | #[cfg(feature = "cockroach")] 222 | async fn transaction_async_with_retry_inner( 223 | &self, 224 | f: &(dyn Fn(Connection) -> BoxFuture<'_, Result, DieselError>> 225 | + Send 226 | + Sync), 227 | retry: &(dyn Fn() -> BoxFuture<'_, bool> + Send + Sync), 228 | ) -> Result, DieselError> { 229 | // Check out a connection once, and use it for the duration of the 230 | // operation.
231 | let conn = Arc::new(self.get_owned_connection()); 232 | 233 | // Refer to CockroachDB's guide on advanced client-side transaction 234 | // retries for the full context: 235 | // https://www.cockroachlabs.com/docs/v23.1/advanced-client-side-transaction-retries 236 | // 237 | // In short, they expect a particular name for this savepoint, but 238 | // Diesel has Opinions on savepoint names, so we use this session 239 | // variable to indicate that any name is valid. 240 | // 241 | // TODO: It may be preferable to set this once per connection -- but 242 | // that'll require more interaction with how sessions with the database 243 | // are constructed. 244 | Self::start_transaction(&conn).await?; 245 | conn.run_with_shared_connection(|conn| { 246 | conn.batch_execute("SET LOCAL force_savepoint_restart = true") 247 | }) 248 | .await?; 249 | // NOTE(review): the `?` on the SET LOCAL statement above, and on `add_retry_savepoint` below, return early without rolling back the BEGIN issued by `start_transaction`; confirm the connection cannot be reused while still inside that transaction. 250 | loop { 251 | // Add a SAVEPOINT to which we can later return. 252 | Self::add_retry_savepoint(&conn).await?; 253 | 254 | let async_conn = Connection(Self::as_async_conn(&conn).0.clone()); 255 | match f(async_conn).await { 256 | Ok(value) => { 257 | // The user-level operation succeeded: try to commit the 258 | // transaction by RELEASE-ing the retry savepoint. 259 | if let Err(err) = Self::commit_transaction(&conn).await { 260 | // Diesel's implementation of "commit_transaction" 261 | // calls "rollback_transaction" in the error path. 262 | // 263 | // We're still in the transaction, but we at least 264 | // tried to ROLLBACK to our savepoint. 265 | if !retryable_error(&err) || !retry().await { 266 | // Bail: ROLLBACK the initial BEGIN statement too. 267 | let _ = Self::rollback_transaction(&conn).await; 268 | return Err(err); 269 | } 270 | // ROLLBACK happened, we want to retry. 271 | continue; 272 | } 273 | 274 | // Commit the top-level transaction too.
275 | Self::commit_transaction(&conn).await?; 276 | return Ok(value); 277 | } 278 | Err(user_error) => { 279 | // The user-level operation failed: ROLLBACK to the retry 280 | // savepoint. 281 | if let Err(first_rollback_err) = Self::rollback_transaction(&conn).await { 282 | // If that ROLLBACK fails, attempt another ROLLBACK and return 283 | // a ROLLBACK error in preference to the user error. 284 | return match Self::rollback_transaction(&conn).await { 285 | Ok(()) => Err(first_rollback_err), 286 | Err(second_rollback_err) => Err(second_rollback_err), 287 | }; 288 | } 289 | 290 | // We rolled back to the retry savepoint; retry if the error is 291 | // retryable and the `retry` callback agrees. 292 | if retryable_error(&user_error) && retry().await { 293 | continue; 294 | } 295 | 296 | // If we aren't retrying, ROLLBACK the BEGIN statement too. 297 | return match Self::rollback_transaction(&conn).await { 298 | Ok(()) => Err(user_error), 299 | Err(err) => Err(err), 300 | }; 301 | } 302 | } 303 | } 304 | } 305 | /// Issues a function `f` as a transaction: commits if `f` returns `Ok`, otherwise rolls back and returns its error (or the ROLLBACK error, if rolling back fails). 306 | async fn transaction_async(&'a self, f: Func) -> Result 307 | where 308 | R: Send + 'static, 309 | E: From + Send + 'static, 310 | Fut: Future> + Send, 311 | Func: FnOnce(Connection) -> Fut + Send, 312 | { 313 | // This function sure has a bunch of generic parameters, which can cause 314 | // a lot of code to be generated, and can slow down compile-time. 315 | // 316 | // This API intends to provide a convenient, generic, shim over the 317 | // dynamic "transaction_async_inner" function below, which 318 | // should avoid being generic. The goal here is to instantiate only one 319 | // significant "body" of this function, while retaining flexibility for 320 | // clients using this library.
321 | 322 | let f = Box::new(move |conn| { 323 | f(conn) 324 | .map(|result| result.map(|r| Box::new(r) as Box)) 325 | .boxed() 326 | }); 327 | 328 | self.transaction_async_inner(f) 329 | .await 330 | .map(|v| *v.downcast::().expect("Should be an 'R' type")) 331 | } 332 | 333 | // NOTE: This function intentionally avoids as many generic parameters as possible 334 | async fn transaction_async_inner<'a, E>( 335 | &'a self, 336 | f: Box< 337 | dyn FnOnce(Connection) -> BoxFuture<'a, Result, E>> 338 | + Send 339 | + 'a, 340 | >, 341 | ) -> Result, E> 342 | where 343 | E: From + Send + 'static, 344 | { 345 | // Check out a connection once, and use it for the duration of the 346 | // operation. 347 | let conn = Arc::new(self.get_owned_connection()); 348 | 349 | // This function mimics the implementation of: 350 | // https://docs.diesel.rs/master/diesel/connection/trait.TransactionManager.html#method.transaction 351 | // 352 | // However, it modifies all callsites to instead issue 353 | // known-to-be-synchronous operations from an asynchronous context. 354 | conn.run_with_shared_connection(|conn| { 355 | Conn::TransactionManager::begin_transaction(conn).map_err(E::from) 356 | }) 357 | .await?; 358 | 359 | // TODO: The ideal interface would pass the "async_conn" object to the 360 | // underlying function "f" by reference. 361 | // 362 | // This would prevent the user-supplied closure + future from using the 363 | // connection *beyond* the duration of the transaction, which would be 364 | // bad. 365 | // 366 | // However, I'm struggling to get these lifetimes to work properly. If 367 | // you can figure out a way to convince the compiler that the reference lives long 368 | // enough to be referenceable by a Future, but short enough that we can 369 | // guarantee it doesn't persist after this function returns, feel 370 | // free to make that change.
371 | let async_conn = Connection(Self::as_async_conn(&conn).0.clone()); 372 | match f(async_conn).await { 373 | Ok(value) => { 374 | conn.run_with_shared_connection(|conn| { 375 | Conn::TransactionManager::commit_transaction(conn).map_err(E::from) 376 | }) 377 | .await?; 378 | Ok(value) 379 | } 380 | Err(user_error) => { 381 | match conn 382 | .run_with_shared_connection(|conn| { 383 | Conn::TransactionManager::rollback_transaction(conn).map_err(E::from) 384 | }) 385 | .await 386 | { 387 | Ok(()) => Err(user_error), 388 | Err(err) => Err(err), 389 | } 390 | } 391 | } 392 | } 393 | } 394 | 395 | /// An async variant of [`diesel::query_dsl::RunQueryDsl`]. 396 | #[async_trait] 397 | pub trait AsyncRunQueryDsl 398 | where 399 | Conn: 'static + DieselConnection, 400 | { 401 | async fn execute_async(self, asc: &AsyncConn) -> Result 402 | where 403 | Self: ExecuteDsl; 404 | /// Async variant of [`RunQueryDsl::load`]. 405 | async fn load_async(self, asc: &AsyncConn) -> Result, DieselError> 406 | where 407 | U: Send + 'static, 408 | Self: LoadQuery<'static, Conn, U>; 409 | /// Async variant of [`RunQueryDsl::get_result`]. 410 | async fn get_result_async(self, asc: &AsyncConn) -> Result 411 | where 412 | U: Send + 'static, 413 | Self: LoadQuery<'static, Conn, U>; 414 | /// Async variant of [`RunQueryDsl::get_results`]. 415 | async fn get_results_async(self, asc: &AsyncConn) -> Result, DieselError> 416 | where 417 | U: Send + 'static, 418 | Self: LoadQuery<'static, Conn, U>; 419 | /// Async variant of [`RunQueryDsl::first`]. 420 | async fn first_async(self, asc: &AsyncConn) -> Result 421 | where 422 | U: Send + 'static, 423 | Self: LimitDsl, 424 | Limit: LoadQuery<'static, Conn, U>; 425 | } 426 | 427 | #[async_trait] 428 | impl AsyncRunQueryDsl for T 429 | where 430 | T: 'static + Send + RunQueryDsl, 431 | Conn: 'static + DieselConnection, 432 | AsyncConn: Send + Sync + AsyncConnection, 433 | { 434 | async fn execute_async(self, asc: &AsyncConn) -> Result 435 | where 436 | Self: ExecuteDsl, 437 | { 438 | asc.run(|conn| self.execute(conn)).await 439 | } 440 | 441 | async fn load_async(self, asc: &AsyncConn) -> Result, DieselError> 442 | where 443 | U: Send +
'static, 444 | Self: LoadQuery<'static, Conn, U>, 445 | { 446 | asc.run(|conn| self.load(conn)).await 447 | } 448 | 449 | async fn get_result_async(self, asc: &AsyncConn) -> Result 450 | where 451 | U: Send + 'static, 452 | Self: LoadQuery<'static, Conn, U>, 453 | { 454 | asc.run(|conn| self.get_result(conn)).await 455 | } 456 | 457 | async fn get_results_async(self, asc: &AsyncConn) -> Result, DieselError> 458 | where 459 | U: Send + 'static, 460 | Self: LoadQuery<'static, Conn, U>, 461 | { 462 | asc.run(|conn| self.get_results(conn)).await 463 | } 464 | 465 | async fn first_async(self, asc: &AsyncConn) -> Result 466 | where 467 | U: Send + 'static, 468 | Self: LimitDsl, 469 | Limit: LoadQuery<'static, Conn, U>, 470 | { 471 | asc.run(|conn| self.first(conn)).await 472 | } 473 | } 474 | /// An async variant of [`diesel::SaveChangesDsl`]. 475 | #[async_trait] 476 | pub trait AsyncSaveChangesDsl 477 | where 478 | Conn: 'static + DieselConnection, 479 | { 480 | async fn save_changes_async(self, asc: &AsyncConn) -> Result 481 | where 482 | Self: Sized, 483 | Conn: diesel::query_dsl::UpdateAndFetchResults, 484 | Output: Send + 'static; 485 | } 486 | 487 | #[async_trait] 488 | impl AsyncSaveChangesDsl for T 489 | where 490 | T: 'static + Send + Sync + diesel::SaveChangesDsl, 491 | Conn: 'static + DieselConnection, 492 | AsyncConn: Send + Sync + AsyncConnection, 493 | { 494 | async fn save_changes_async(self, asc: &AsyncConn) -> Result 495 | where 496 | Conn: diesel::query_dsl::UpdateAndFetchResults, 497 | Output: Send + 'static, 498 | { 499 | asc.run(|conn| self.save_changes(conn)).await 500 | } 501 | } 502 | -------------------------------------------------------------------------------- /src/connection.rs: -------------------------------------------------------------------------------- 1 | //! An async wrapper around a [`diesel::Connection`].
2 | 3 | use async_trait::async_trait; 4 | use diesel::r2d2::R2D2Connection; 5 | use std::sync::{Arc, Mutex, MutexGuard}; 6 | use tokio::task; 7 | 8 | /// An async-safe analogue of any connection that implements 9 | /// [`diesel::Connection`]. 10 | /// 11 | /// These connections are created by [`crate::ConnectionManager`]. 12 | /// 13 | /// All blocking methods within this type delegate to 14 | /// [`tokio::task::spawn_blocking`], meaning they won't block 15 | /// any asynchronous work or threads. 16 | pub struct Connection(pub(crate) Arc>); 17 | 18 | impl Connection { 19 | pub fn new(c: C) -> Self { 20 | Self(Arc::new(Mutex::new(c))) 21 | } 22 | 23 | // Accesses the underlying connection. 24 | // 25 | // As this is a blocking mutex, it's recommended to avoid invoking 26 | // this function from an asynchronous context. 27 | pub(crate) fn inner(&self) -> MutexGuard<'_, C> { 28 | self.0.lock().unwrap() 29 | } 30 | } 31 | 32 | #[async_trait] 33 | impl crate::AsyncSimpleConnection for Connection 34 | where 35 | Conn: 'static + R2D2Connection, 36 | { 37 | #[inline] 38 | async fn batch_execute_async(&self, query: &str) -> Result<(), diesel::result::Error> { 39 | let diesel_conn = Connection(self.0.clone()); 40 | let query = query.to_string(); 41 | task::spawn_blocking(move || diesel_conn.inner().batch_execute(&query)) 42 | .await 43 | .unwrap() // Propagate panics 44 | } 45 | } 46 | 47 | #[async_trait] 48 | impl crate::AsyncR2D2Connection for Connection where Conn: 'static + R2D2Connection 49 | {} 50 | 51 | #[async_trait] 52 | impl crate::AsyncConnection for Connection 53 | where 54 | Conn: 'static + R2D2Connection, 55 | Connection: crate::AsyncSimpleConnection, 56 | { 57 | fn get_owned_connection(&self) -> Self { 58 | Connection(self.0.clone()) 59 | } 60 | 61 | // Accesses the connection synchronously, protected by a mutex. 62 | // 63 | // Avoid calling from asynchronous contexts. 
64 | fn as_sync_conn(&self) -> MutexGuard<'_, Conn> { 65 | self.inner() 66 | } 67 | 68 | fn as_async_conn(&self) -> &Connection { 69 | self 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/connection_manager.rs: -------------------------------------------------------------------------------- 1 | //! An async-safe connection pool for Diesel. 2 | 3 | use crate::{Connection, ConnectionError}; 4 | use async_trait::async_trait; 5 | use diesel::r2d2::{self, ManageConnection, R2D2Connection}; 6 | use std::sync::{Arc, Mutex}; 7 | 8 | /// A connection manager which implements [`bb8::ManageConnection`] to 9 | /// integrate with bb8. 10 | /// 11 | /// ```no_run 12 | /// use async_bb8_diesel::AsyncRunQueryDsl; 13 | /// use diesel::prelude::*; 14 | /// use diesel::pg::PgConnection; 15 | /// 16 | /// table! { 17 | /// users (id) { 18 | /// id -> Integer, 19 | /// } 20 | /// } 21 | /// 22 | /// #[tokio::main] 23 | /// async fn main() { 24 | /// use users::dsl; 25 | /// 26 | /// // Creates a Diesel-specific connection manager for bb8. 
27 | /// let mgr = async_bb8_diesel::ConnectionManager::::new("localhost:1234"); 28 | /// let pool = bb8::Pool::builder().build(mgr).await.unwrap(); 29 | /// 30 | /// diesel::insert_into(dsl::users) 31 | /// .values(dsl::id.eq(1337)) 32 | /// .execute_async(&*pool.get().await.unwrap()) 33 | /// .await 34 | /// .unwrap(); 35 | /// } 36 | /// ``` 37 | #[derive(Clone)] 38 | pub struct ConnectionManager { 39 | inner: Arc>>, 40 | } 41 | 42 | impl ConnectionManager { 43 | pub fn new>(database_url: S) -> Self { 44 | Self { 45 | inner: Arc::new(Mutex::new(r2d2::ConnectionManager::new(database_url))), 46 | } 47 | } 48 | 49 | async fn run_blocking(&self, f: F) -> R 50 | where 51 | R: Send + 'static, 52 | F: Send + 'static + FnOnce(&r2d2::ConnectionManager) -> R, 53 | { 54 | let cloned = self.inner.clone(); 55 | tokio::task::spawn_blocking(move || f(&*cloned.lock().unwrap())) 56 | .await 57 | // Intentionally panic if the inner closure panics. 58 | .unwrap() 59 | } 60 | } 61 | 62 | #[async_trait] 63 | impl bb8::ManageConnection for ConnectionManager 64 | where 65 | T: R2D2Connection + Send + 'static, 66 | { 67 | type Connection = Connection; 68 | type Error = ConnectionError; 69 | 70 | async fn connect(&self) -> Result { 71 | self.run_blocking(|m| m.connect()) 72 | .await 73 | .map(Connection::new) 74 | .map_err(ConnectionError::Connection) 75 | } 76 | 77 | async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { 78 | let c = Connection(conn.0.clone()); 79 | self.run_blocking(move |m| { 80 | m.is_valid(&mut *c.inner())?; 81 | Ok(()) 82 | }) 83 | .await 84 | } 85 | 86 | fn has_broken(&self, _: &mut Self::Connection) -> bool { 87 | // Diesel returns this value internally. We have no way of calling the 88 | // inner method without blocking as this method is not async, but `bb8` 89 | // indicates that this method is not mandatory. 
90 | false 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | //! Error types returned by connections and by the connection pool, 2 | //! along with [`OptionalExtension`] for query results. 3 | //! 4 | //! [`ConnectionError`] describes failures on a single connection, while 5 | //! [`PoolError`] additionally describes failures to obtain a connection 6 | //! from the pool (such as a bb8 timeout). 7 | 8 | use diesel::result::Error as DieselError; 9 | use diesel::OptionalExtension as OtherOptionalExtension; 10 | use thiserror::Error; 11 | 12 | /// Syntactic sugar around a Result returning a [`ConnectionError`]. 13 | pub type ConnectionResult = Result; 14 | 15 | /// Errors returned directly from a [`Connection`](crate::Connection). 16 | #[derive(Error, Debug)] 17 | pub enum ConnectionError { 18 | #[error("Connection error: {0}")] 19 | Connection(#[from] diesel::r2d2::Error), 20 | 21 | #[error("Failed to issue a query: {0}")] 22 | Query(#[from] DieselError), 23 | } 24 | 25 | /// Syntactic sugar around a Result returning a [`PoolError`]. 26 | pub type PoolResult = Result; 27 | 28 | /// Variant of [`diesel::prelude::OptionalExtension`] for results that carry a [`ConnectionError`]. 29 | pub trait OptionalExtension { 30 | fn optional(self) -> Result, ConnectionError>; 31 | } 32 | 33 | impl OptionalExtension for Result { 34 | fn optional(self) -> Result, ConnectionError> { 35 | let self_as_query_result: diesel::QueryResult = match self { 36 | Ok(value) => Ok(value), 37 | Err(ConnectionError::Query(error_kind)) => Err(error_kind), 38 | Err(e) => return Err(e), 39 | }; 40 | 41 | self_as_query_result 42 | .optional() 43 | .map_err(ConnectionError::Query) 44 | } 45 | } 46 | 47 | /// Describes an error performing an operation from a connection pool. 48 | /// 49 | /// This is a superset of [`ConnectionError`] which may also 50 | /// propagate errors encountered while accessing the connection pool.
51 | #[derive(Error, Debug)] 52 | pub enum PoolError { 53 | #[error("Failure accessing a connection: {0}")] 54 | Connection(#[from] ConnectionError), 55 | 56 | #[error("BB8 Timeout accessing connection")] 57 | Timeout, 58 | } 59 | 60 | impl From for PoolError { 61 | fn from(error: DieselError) -> Self { 62 | PoolError::Connection(ConnectionError::Query(error)) 63 | } 64 | } 65 | 66 | impl From> for PoolError { 67 | fn from(error: bb8::RunError) -> Self { 68 | match error { 69 | bb8::RunError::User(e) => PoolError::Connection(e), 70 | bb8::RunError::TimedOut => PoolError::Timeout, 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! bb8-diesel allows the bb8 asynchronous connection pool 2 | //! to be used underneath Diesel. 3 | //! 4 | //! This is currently implemented against Diesel's synchronous 5 | //! API, with calls to [`tokio::task::spawn_blocking`] to safely 6 | //! perform synchronous operations from an asynchronous task. 7 | 8 | mod async_traits; 9 | mod connection; 10 | mod connection_manager; 11 | mod error; 12 | 13 | pub use async_traits::{ 14 | AsyncConnection, AsyncR2D2Connection, AsyncRunQueryDsl, AsyncSaveChangesDsl, 15 | AsyncSimpleConnection, 16 | }; 17 | pub use connection::Connection; 18 | pub use connection_manager::ConnectionManager; 19 | pub use error::{ConnectionError, ConnectionResult, OptionalExtension, PoolError, PoolResult}; 20 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Tests 2 | 3 | This directory includes some integration tests, specifically with the 4 | CockroachDB database, which uses a PostgreSQL protocol. 5 | 6 | These tests rely on the `crdb-harness` crate to download CockroachDB 7 | at a particular version. 
8 | -------------------------------------------------------------------------------- /tests/test.rs: -------------------------------------------------------------------------------- 1 | // This Source Code Form is subject to the terms of the Mozilla Public 2 | // License, v. 2.0. If a copy of the MPL was not distributed with this 3 | // file, You can obtain one at https://mozilla.org/MPL/2.0/. 4 | 5 | use async_bb8_diesel::{ 6 | AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, AsyncSimpleConnection, ConnectionError, 7 | }; 8 | use crdb_harness::{CockroachInstance, CockroachStarterBuilder}; 9 | use diesel::OptionalExtension; 10 | use diesel::{pg::PgConnection, prelude::*}; 11 | 12 | table! { 13 | user (id) { 14 | id -> Int4, 15 | name -> Text, 16 | } 17 | } 18 | 19 | const SCHEMA: &'static str = r#" 20 | CREATE DATABASE test; 21 | CREATE TABLE IF NOT EXISTS test.public.user ( 22 | id INT4 PRIMARY KEY, 23 | name STRING(512) 24 | ); 25 | "#; 26 | 27 | #[derive(AsChangeset, Insertable, Queryable, PartialEq, Clone)] 28 | #[diesel(table_name = user)] 29 | pub struct User { 30 | pub id: i32, 31 | pub name: String, 32 | } 33 | 34 | #[derive(AsChangeset, Identifiable)] 35 | #[diesel(table_name = user)] 36 | pub struct UserUpdate<'a> { 37 | pub id: i32, 38 | pub name: &'a str, 39 | } 40 | 41 | // Creates a new CRDB database under test 42 | async fn test_start() -> CockroachInstance { 43 | let crdb = CockroachStarterBuilder::new() 44 | .redirect_stdio_to_files() 45 | .build() 46 | .expect("Failed to create CockroachDB builder") 47 | .start() 48 | .await 49 | .expect("Failed to start CockroachDB"); 50 | 51 | let client = crdb.connect().await.expect("Could not connect to database"); 52 | 53 | client 54 | .batch_execute(&SCHEMA) 55 | .await 56 | .expect("Failed to initialize database"); 57 | 58 | crdb 59 | } 60 | 61 | // Terminates a test CRDB database 62 | async fn test_end(mut crdb: CockroachInstance) { 63 | crdb.cleanup() 64 | .await 65 | .expect("Failed to clean up 
CockroachDB"); 66 | } 67 | 68 | #[tokio::test] 69 | async fn test_insert_load_update_delete() { 70 | let crdb = test_start().await; 71 | 72 | let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); 73 | let pool = bb8::Pool::builder().build(manager).await.unwrap(); 74 | let conn = pool.get().await.unwrap(); 75 | 76 | use user::dsl; 77 | // Insert by values 78 | let _ = diesel::insert_into(dsl::user) 79 | .values((dsl::id.eq(0), dsl::name.eq("Jim"))) 80 | .execute_async(&*conn) 81 | .await 82 | .unwrap(); 83 | 84 | // Insert by structure 85 | let _ = diesel::insert_into(dsl::user) 86 | .values(User { 87 | id: 1, 88 | name: "Xiang".to_string(), 89 | }) 90 | .execute_async(&*conn) 91 | .await 92 | .unwrap(); 93 | 94 | // Load 95 | let users = dsl::user.get_results_async::(&*conn).await.unwrap(); 96 | assert_eq!(users.len(), 2); 97 | 98 | // Update 99 | let _ = diesel::update(dsl::user) 100 | .filter(dsl::id.eq(0)) 101 | .set(dsl::name.eq("Jim, But Different")) 102 | .execute_async(&*conn) 103 | .await 104 | .unwrap(); 105 | 106 | // Update via save_changes 107 | let update = &UserUpdate { 108 | id: 0, 109 | name: "The Artist Formerly Known As Jim", 110 | }; 111 | let _ = update.save_changes_async::(&*conn).await.unwrap(); 112 | 113 | // Delete 114 | let _ = diesel::delete(dsl::user) 115 | .filter(dsl::id.eq(0)) 116 | .execute_async(&*conn) 117 | .await 118 | .unwrap(); 119 | 120 | test_end(crdb).await; 121 | } 122 | 123 | #[tokio::test] 124 | async fn test_transaction() { 125 | let crdb = test_start().await; 126 | 127 | let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); 128 | let pool = bb8::Pool::builder().build(manager).await.unwrap(); 129 | let conn = pool.get().await.unwrap(); 130 | 131 | use user::dsl; 132 | 133 | // Transaction with multiple operations 134 | conn.transaction_async(|conn| async move { 135 | diesel::insert_into(dsl::user) 136 | .values((dsl::id.eq(3), dsl::name.eq("Sally"))) 137 | 
.execute_async(&conn) 138 | .await 139 | .unwrap(); 140 | diesel::insert_into(dsl::user) 141 | .values((dsl::id.eq(4), dsl::name.eq("Arjun"))) 142 | .execute_async(&conn) 143 | .await 144 | .unwrap(); 145 | Ok::<(), ConnectionError>(()) 146 | }) 147 | .await 148 | .unwrap(); 149 | 150 | test_end(crdb).await; 151 | } 152 | 153 | #[tokio::test] 154 | async fn test_transaction_automatic_retry_success_case() { 155 | let crdb = test_start().await; 156 | 157 | let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); 158 | let pool = bb8::Pool::builder().build(manager).await.unwrap(); 159 | let conn = pool.get().await.unwrap(); 160 | 161 | use user::dsl; 162 | 163 | // Transaction that can retry but does not need to. 164 | assert_eq!(conn.transaction_depth().await.unwrap(), 0); 165 | conn.transaction_async_with_retry( 166 | |conn| async move { 167 | assert!(conn.transaction_depth().await.unwrap() > 0); 168 | diesel::insert_into(dsl::user) 169 | .values((dsl::id.eq(3), dsl::name.eq("Sally"))) 170 | .execute_async(&conn) 171 | .await?; 172 | Ok(()) 173 | }, 174 | || async { panic!("Should not attempt to retry this operation") }, 175 | ) 176 | .await 177 | .expect("Transaction failed"); 178 | assert_eq!(conn.transaction_depth().await.unwrap(), 0); 179 | 180 | test_end(crdb).await; 181 | } 182 | 183 | #[tokio::test] 184 | async fn test_transaction_automatic_retry_explicit_rollback() { 185 | let crdb = test_start().await; 186 | 187 | let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); 188 | let pool = bb8::Pool::builder().build(manager).await.unwrap(); 189 | let conn = pool.get().await.unwrap(); 190 | 191 | use std::sync::{Arc, Mutex}; 192 | 193 | let transaction_attempted_count = Arc::new(Mutex::new(0)); 194 | let should_retry_query_count = Arc::new(Mutex::new(0)); 195 | 196 | // Test a transaction that: 197 | // 198 | // 1. Retries on the first call 199 | // 2. 
Explicitly rolls back on the second call 200 | assert_eq!(conn.transaction_depth().await.unwrap(), 0); 201 | let err = conn 202 | .transaction_async_with_retry( 203 | |_conn| { 204 | let transaction_attempted_count = transaction_attempted_count.clone(); 205 | async move { 206 | let mut count = transaction_attempted_count.lock().unwrap(); 207 | *count += 1; 208 | 209 | if *count < 2 { 210 | eprintln!("test: Manually restarting txn"); 211 | return Err::<(), _>(diesel::result::Error::DatabaseError( 212 | diesel::result::DatabaseErrorKind::SerializationFailure, 213 | Box::new("restart transaction".to_string()), 214 | )); 215 | } 216 | eprintln!("test: Manually rolling back txn"); 217 | return Err(diesel::result::Error::RollbackTransaction); 218 | } 219 | }, 220 | || async { 221 | *should_retry_query_count.lock().unwrap() += 1; 222 | true 223 | }, 224 | ) 225 | .await 226 | .expect_err("Transaction should have failed"); 227 | 228 | assert_eq!(err, diesel::result::Error::RollbackTransaction); 229 | assert_eq!(conn.transaction_depth().await.unwrap(), 0); 230 | 231 | // The transaction closure should have been attempted twice, but 232 | // we should have only asked whether or not to retry once -- after 233 | // the first failure, but not the second. 
234 | assert_eq!(*transaction_attempted_count.lock().unwrap(), 2); 235 | assert_eq!(*should_retry_query_count.lock().unwrap(), 1); 236 | 237 | test_end(crdb).await; 238 | } 239 | 240 | #[tokio::test] 241 | async fn test_transaction_automatic_retry_injected_errors() { 242 | let crdb = test_start().await; 243 | 244 | let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); 245 | let pool = bb8::Pool::builder().build(manager).await.unwrap(); 246 | let conn = pool.get().await.unwrap(); 247 | 248 | use std::sync::{Arc, Mutex}; 249 | 250 | let transaction_attempted_count = Arc::new(Mutex::new(0)); 251 | let should_retry_query_count = Arc::new(Mutex::new(0)); 252 | 253 | // Tests a transaction that is forced to retry by CockroachDB. 254 | // 255 | // By setting this session variable, we expect that: 256 | // - "any statement executed inside of an explicit transaction (with the 257 | // exception of SET statements) will return a transaction retry error." 258 | // - "after the 3rd retry error, the transaction will proceed as 259 | // normal" 260 | // 261 | // See: https://www.cockroachlabs.com/docs/v23.1/transaction-retry-error-example#test-transaction-retry-logic 262 | // for more details 263 | const EXPECTED_ERR_COUNT: usize = 3; 264 | conn.batch_execute_async("SET inject_retry_errors_enabled = true") 265 | .await 266 | .expect("Failed to inject error"); 267 | assert_eq!(conn.transaction_depth().await.unwrap(), 0); 268 | conn.transaction_async_with_retry( 269 | |conn| { 270 | let transaction_attempted_count = transaction_attempted_count.clone(); 271 | async move { 272 | *transaction_attempted_count.lock().unwrap() += 1; 273 | 274 | use user::dsl; 275 | let _ = diesel::insert_into(dsl::user) 276 | .values((dsl::id.eq(0), dsl::name.eq("Jim"))) 277 | .execute_async(&conn) 278 | .await?; 279 | Ok(()) 280 | } 281 | }, 282 | || async { 283 | *should_retry_query_count.lock().unwrap() += 1; 284 | true 285 | }, 286 | ) 287 | .await 288 | 
.expect("Transaction should have succeeded"); 289 | assert_eq!(conn.transaction_depth().await.unwrap(), 0); 290 | 291 | // The transaction closure should have been attempted once per 292 | // injected retry error plus one final, successful attempt, and we 293 | // should have been asked whether to retry once per injected error. 294 | assert_eq!( 295 | *transaction_attempted_count.lock().unwrap(), 296 | EXPECTED_ERR_COUNT + 1 297 | ); 298 | assert_eq!( 299 | *should_retry_query_count.lock().unwrap(), 300 | EXPECTED_ERR_COUNT 301 | ); 302 | 303 | test_end(crdb).await; 304 | } 305 | 306 | #[tokio::test] 307 | async fn test_transaction_automatic_retry_does_not_retry_non_retryable_errors() { 308 | let crdb = test_start().await; 309 | 310 | let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); 311 | let pool = bb8::Pool::builder().build(manager).await.unwrap(); 312 | let conn = pool.get().await.unwrap(); 313 | 314 | // Test a transaction that: 315 | // 316 | // Fails with a non-retryable error. It should exit immediately.
317 | assert_eq!(conn.transaction_depth().await.unwrap(), 0); 318 | assert_eq!( 319 | conn.transaction_async_with_retry( 320 | |_| async { Err::<(), _>(diesel::result::Error::NotFound) }, 321 | || async { panic!("Should not attempt to retry this operation") } 322 | ) 323 | .await 324 | .expect_err("Transaction should have failed"), 325 | diesel::result::Error::NotFound, 326 | ); 327 | assert_eq!(conn.transaction_depth().await.unwrap(), 0); 328 | 329 | test_end(crdb).await; 330 | } 331 | 332 | #[tokio::test] 333 | async fn test_transaction_automatic_retry_nested_transactions_fail() { 334 | let crdb = test_start().await; 335 | 336 | let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); 337 | let pool = bb8::Pool::builder().build(manager).await.unwrap(); 338 | let conn = pool.get().await.unwrap(); 339 | 340 | #[derive(Debug, PartialEq)] 341 | struct OnlyReturnFromOuterTransaction {} 342 | 343 | // This outer transaction should succeed immediately... 344 | assert_eq!(conn.transaction_depth().await.unwrap(), 0); 345 | assert_eq!( 346 | OnlyReturnFromOuterTransaction {}, 347 | conn.transaction_async_with_retry( 348 | |conn| async move { 349 | // ... but this inner transaction should fail! We do not support 350 | // retryable nested transactions. 351 | let err = conn 352 | .transaction_async_with_retry( 353 | |_| async { 354 | panic!("Shouldn't run"); 355 | 356 | // Adding this unreachable statement for type inference 357 | #[allow(unreachable_code)] 358 | Ok(()) 359 | }, 360 | || async { panic!("Shouldn't retry inner transaction") }, 361 | ) 362 | .await 363 | .expect_err("Nested transaction should have failed"); 364 | assert_eq!(err, diesel::result::Error::AlreadyInTransaction); 365 | 366 | // Control should still return to the outer transaction after the 367 | // inner one fails, so we explicitly return a marker value here.
368 | Ok(OnlyReturnFromOuterTransaction {}) 369 | }, 370 | || async { panic!("Shouldn't retry outer transaction") }, 371 | ) 372 | .await 373 | .expect("Transaction should have succeeded") 374 | ); 375 | assert_eq!(conn.transaction_depth().await.unwrap(), 0); 376 | 377 | test_end(crdb).await; 378 | } 379 | 380 | #[tokio::test] 381 | async fn test_transaction_custom_error() { 382 | let crdb = test_start().await; 383 | 384 | let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); 385 | let pool = bb8::Pool::builder().build(manager).await.unwrap(); 386 | let conn = pool.get().await.unwrap(); 387 | 388 | // Demonstrates an error which may be returned from transactions. 389 | #[derive(thiserror::Error, Debug)] 390 | enum MyError { 391 | #[error("DB error")] 392 | Db(#[from] ConnectionError), 393 | 394 | #[error("Custom transaction error")] 395 | Other, 396 | } 397 | 398 | impl From for MyError { 399 | fn from(error: diesel::result::Error) -> Self { 400 | MyError::Db(ConnectionError::Query(error)) 401 | } 402 | } 403 | 404 | use user::dsl; 405 | 406 | // Transaction returning custom error types. 
407 | let _: MyError = conn 408 | .transaction_async(|conn| async move { 409 | diesel::insert_into(dsl::user) 410 | .values((dsl::id.eq(1), dsl::name.eq("Ishmael"))) 411 | .execute_async(&conn) 412 | .await?; 413 | return Err::<(), MyError>(MyError::Other {}); 414 | }) 415 | .await 416 | .unwrap_err(); 417 | 418 | test_end(crdb).await; 419 | } 420 | 421 | #[tokio::test] 422 | async fn test_optional_extension() { 423 | let crdb = test_start().await; 424 | 425 | let manager = async_bb8_diesel::ConnectionManager::::new(&crdb.pg_config().url); 426 | let pool = bb8::Pool::builder().build(manager).await.unwrap(); 427 | let conn = pool.get().await.unwrap(); 428 | 429 | use user::dsl; 430 | 431 | // Access the result via OptionalExtension 432 | assert!(dsl::user 433 | .filter(dsl::id.eq(12345)) 434 | .first_async::(&*conn) 435 | .await 436 | .optional() 437 | .unwrap() 438 | .is_none()); 439 | 440 | test_end(crdb).await; 441 | } 442 | --------------------------------------------------------------------------------