Skip to content

Commit

Permalink
feat: Implement before_connect callback to modify connect options.
Browse files Browse the repository at this point in the history
Allows the user to see and maybe modify the connect options before
each attempt to connect to a database. May be used in a number of
ways, e.g.:
 - adding jitter to connection lifetime
 - validating/setting a per-connection password
 - using a custom server discovery process
  • Loading branch information
jrasanen committed Oct 14, 2024
1 parent 028084b commit 76b5c16
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 9 deletions.
45 changes: 36 additions & 9 deletions sqlx-core/src/pool/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crossbeam_queue::ArrayQueue;

use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser};

use std::borrow::Cow;
use std::cmp;
use std::future::Future;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
Expand Down Expand Up @@ -324,6 +325,36 @@ impl<DB: Database> PoolInner<DB> {
Ok(acquired)
}

/// Attempts to get connect options, possibly modify using before_connect, then connect.
///
/// Wrapping this code in a timeout allows the total time taken for these steps to
/// be bounded by the connection deadline.
async fn get_connect_options_and_connect(
self: &Arc<Self>,
num_attempts: u32,
) -> Result<DB::Connection, Error> {
// clone the connect options arc so it can be used without holding the RwLockReadGuard
// across an async await point
let connect_options_arc = self
.connect_options
.read()
.expect("write-lock holder panicked")
.clone();

let connect_options = if let Some(callback) = &self.options.before_connect {
callback(connect_options_arc.as_ref(), num_attempts)
.await
.map_err(|error| {
tracing::error!(%error, "error returned from before_connect");
error
})?
} else {
Cow::Borrowed(connect_options_arc.as_ref())
};

connect_options.connect().await
}

pub(super) async fn connect(
self: &Arc<Self>,
deadline: Instant,
Expand All @@ -335,21 +366,17 @@ impl<DB: Database> PoolInner<DB> {

let mut backoff = Duration::from_millis(10);
let max_backoff = deadline_as_timeout(deadline)? / 5;
let mut num_attempts: u32 = 0;

loop {
let timeout = deadline_as_timeout(deadline)?;

// clone the connect options arc so it can be used without holding the RwLockReadGuard
// across an async await point
let connect_options = self
.connect_options
.read()
.expect("write-lock holder panicked")
.clone();
num_attempts += 1;

// result here is `Result<Result<C, Error>, TimeoutError>`
// if this block does not return, sleep for the backoff timeout and try again
match crate::rt::timeout(timeout, connect_options.connect()).await {
match crate::rt::timeout(timeout, self.get_connect_options_and_connect(num_attempts))
.await
{
// successfully established connection
Ok(Ok(mut raw)) => {
// See comment on `PoolOptions::after_connect`
Expand Down
63 changes: 63 additions & 0 deletions sqlx-core/src/pool/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::pool::inner::PoolInner;
use crate::pool::Pool;
use futures_core::future::BoxFuture;
use log::LevelFilter;
use std::borrow::Cow;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use std::time::{Duration, Instant};
Expand Down Expand Up @@ -44,6 +45,18 @@ use std::time::{Duration, Instant};
/// the perspectives of both API designer and consumer.
pub struct PoolOptions<DB: Database> {
pub(crate) test_before_acquire: bool,
pub(crate) before_connect: Option<
Arc<
dyn Fn(
&<DB::Connection as Connection>::Options,
u32,
)
-> BoxFuture<'_, Result<Cow<'_, <DB::Connection as Connection>::Options>, Error>>
+ 'static
+ Send
+ Sync,
>,
>,
pub(crate) after_connect: Option<
Arc<
dyn Fn(&mut DB::Connection, PoolConnectionMetadata) -> BoxFuture<'_, Result<(), Error>>
Expand Down Expand Up @@ -94,6 +107,7 @@ impl<DB: Database> Clone for PoolOptions<DB> {
fn clone(&self) -> Self {
PoolOptions {
test_before_acquire: self.test_before_acquire,
before_connect: self.before_connect.clone(),
after_connect: self.after_connect.clone(),
before_acquire: self.before_acquire.clone(),
after_release: self.after_release.clone(),
Expand Down Expand Up @@ -143,6 +157,7 @@ impl<DB: Database> PoolOptions<DB> {
pub fn new() -> Self {
Self {
// User-specifiable routines
before_connect: None,
after_connect: None,
before_acquire: None,
after_release: None,
Expand Down Expand Up @@ -339,6 +354,54 @@ impl<DB: Database> PoolOptions<DB> {
self
}

/// Perform an asynchronous action before connecting to the database.
///
/// This operation is performed on every attempt to connect, including retries. The
/// current `ConnectOptions` is passed, and this may be passed unchanged, or modified
/// after cloning. The current connection attempt is passed as the second parameter
/// (starting at 1).
///
/// If the operation returns with an error, then the connection attempt fails without
/// attempting further retries. The operation therefore may need to implement error
/// handling and/or value caching to avoid failing the connection attempt.
///
/// # Example: Per-Request Authentication
/// This callback may be used to modify values in the database's `ConnectOptions`, before
/// connecting to the database.
///
/// This example is written for PostgreSQL but can likely be adapted to other databases.
///
/// ```no_run
/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
/// use std::borrow::Cow;
/// use sqlx::Executor;
/// use sqlx::postgres::PgPoolOptions;
///
/// let pool = PgPoolOptions::new()
/// .before_connect(move |opts, _num_attempts| Box::pin(async move {
/// Ok(Cow::Owned(opts.clone().password("abc")))
/// }))
/// .connect("postgres:// …").await?;
/// # Ok(())
/// # }
/// ```
///
/// For a discussion on why `Box::pin()` is required, see [the type-level docs][Self].
pub fn before_connect<F>(mut self, callback: F) -> Self
where
for<'c> F: Fn(
&'c <DB::Connection as Connection>::Options,
u32,
)
-> BoxFuture<'c, crate::Result<Cow<'c, <DB::Connection as Connection>::Options>>>
+ 'static
+ Send
+ Sync,
{
self.before_connect = Some(Arc::new(callback));
self
}

/// Perform an asynchronous action after connecting to the database.
///
/// If the operation returns with an error then the error is logged, the connection is closed
Expand Down

0 comments on commit 76b5c16

Please sign in to comment.