Add PgBindIter for encoding and use it as the implementation for encoding &[T] #3651

Open · wants to merge 7 commits into base: main
152 changes: 152 additions & 0 deletions sqlx-postgres/src/bind_iter.rs
@@ -0,0 +1,152 @@
use sqlx_core::{
database::Database,
encode::{Encode, IsNull},
error::BoxDynError,
types::Type,
};

use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};

// not exported but pub because it is used in the extension trait
pub struct PgBindIter<I>(I);

/// Iterator extension trait enabling iterators to encode arrays in Postgres.
///
/// Because of the blanket impl of `PgHasArrayType` for all references,
/// the iterators can yield borrowed values instead of cloned or copied
/// ones and encoding still works.
///
/// Previously, this example would have required three separate `Vec`s:
/// iterating once per field to collect the items, and then iterating
/// over each collection again to encode it.
///
/// Now each field is iterated only once and nothing is collected up
/// front, improving both speed and memory usage while allowing much
/// more flexibility in the underlying collection.
///
/// ```rust,no_run
/// # async fn test_bind_iter() -> Result<(), sqlx::error::BoxDynError> {
/// # use sqlx::types::chrono::{DateTime, Utc};
/// # use sqlx::Connection;
/// # fn people() -> &'static [Person] {
/// # &[]
/// # }
/// # let mut conn = <sqlx::Postgres as sqlx::Database>::Connection::connect("dummyurl").await?;
/// use sqlx::postgres::PgBindIterExt;
///
/// #[derive(sqlx::FromRow)]
/// struct Person {
/// id: i64,
/// name: String,
/// birthdate: DateTime<Utc>,
/// }
///
/// # let people: &[Person] = people();
/// sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
/// .bind(people.iter().map(|p| p.id).bind_iter())
/// .bind(people.iter().map(|p| &p.name).bind_iter())
/// .bind(people.iter().map(|p| &p.birthdate).bind_iter())
/// .execute(&mut conn)
/// .await?;
///
/// # Ok(())
/// # }
/// ```
pub trait PgBindIterExt: Iterator + Sized {
fn bind_iter(self) -> PgBindIter<Self>;
}

impl<I: Iterator + Sized> PgBindIterExt for I {
fn bind_iter(self) -> PgBindIter<I> {
PgBindIter(self)
}
}

impl<I> Type<Postgres> for PgBindIter<I>
where
I: Iterator,
<I as Iterator>::Item: Type<Postgres> + PgHasArrayType,
{
fn type_info() -> <Postgres as Database>::TypeInfo {
<I as Iterator>::Item::array_type_info()
}
fn compatible(ty: &PgTypeInfo) -> bool {
<I as Iterator>::Item::array_compatible(ty)
}
}

impl<'q, I> PgBindIter<I>
where
I: Iterator,
<I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
{
fn encode_inner(
// need ownership to iterate
mut iter: I,
buf: &mut PgArgumentBuffer,
) -> Result<IsNull, BoxDynError> {
let lower_size_hint = iter.size_hint().0;
let first = iter.next();
let type_info = first
.as_ref()
.and_then(Encode::produces)
.unwrap_or_else(<I as Iterator>::Item::type_info);

buf.extend(&1_i32.to_be_bytes()); // number of dimensions
buf.extend(&0_i32.to_be_bytes()); // flags

match type_info.0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
PgType::DeclareArrayOf(array) => buf.patch_array_type(array),

ty => {
buf.extend(&ty.oid().0.to_be_bytes());
}
}

let len_start = buf.len();
buf.extend(0_i32.to_be_bytes()); // len (unknown so far)
buf.extend(1_i32.to_be_bytes()); // lower bound

match first {
Some(first) => buf.encode(first)?,
None => return Ok(IsNull::No),
}

let mut count = 1_i32;
const MAX: usize = i32::MAX as usize - 1;

for value in (&mut iter).take(MAX) {
buf.encode(value)?;
count += 1;
}

const OVERFLOW: usize = i32::MAX as usize + 1;
if iter.next().is_some() {
let iter_size = std::cmp::max(lower_size_hint, OVERFLOW);
return Err(format!("encoded iterator is too large for Postgres: {iter_size}").into());
}

// set the length now that we know what it is.
buf[len_start..(len_start + 4)].copy_from_slice(&count.to_be_bytes());

Ok(IsNull::No)
}
}
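For reference, the header that `encode_inner` writes follows Postgres's binary array format: number of dimensions, flags, element type OID, then a length and lower bound per dimension, followed by length-prefixed elements. A std-only sketch for a one-dimensional `int4` array (the function name is illustrative; OID 23 is `int4` in the Postgres catalog, whereas sqlx patches unresolved OIDs rather than hardcoding them):

```rust
/// Encode a slice of i32 in Postgres's binary array format (1-D, no NULLs).
fn encode_int4_array(values: &[i32]) -> Vec<u8> {
    let mut buf = Vec::new();
    buf.extend(1_i32.to_be_bytes()); // number of dimensions
    buf.extend(0_i32.to_be_bytes()); // flags (0 = no NULL bitmap)
    buf.extend(23_i32.to_be_bytes()); // element type OID (int4)
    buf.extend((values.len() as i32).to_be_bytes()); // dimension length
    buf.extend(1_i32.to_be_bytes()); // lower bound (Postgres arrays are 1-based)
    for v in values {
        buf.extend(4_i32.to_be_bytes()); // element byte length
        buf.extend(v.to_be_bytes()); // element payload, big-endian
    }
    buf
}
```

`encode_inner` differs in one key way: it cannot know the length up front, so it reserves the four length bytes at `len_start` and patches them in after iterating.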

impl<'q, I> Encode<'q, Postgres> for PgBindIter<I>
where
// Clone is required for the encode_by_ref call since we can't iterate with a shared reference
I: Iterator + Clone,
Collaborator:

The plan I had for this was to use Cell<Option<I>> since the Clone bound could theoretically be hard to satisfy for some iterators.

Using Cell would make it !Sync and only able to be used once, but those seem like pretty reasonable tradeoffs. I don't expect anyone to need to share this across threads or need to bind the same iterator more than once (they could just reuse the same placeholder number to refer to the existing binding, or create a new iterator).

It seems like maybe we should split Encode into two traits: one that encodes by-value and one by-reference, then a blanket impl of the former over the latter. The by-reference trait would be the one that most types would implement, but the by-value trait would be the one that all the APIs actually accept.

Contributor Author:

Ok, I can make the change. I don't think it will affect the send-ness of the execution future so that should be fine.

I agree with splitting up the Encode trait. It would make this simpler.
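The split discussed above can be sketched in miniature. The trait names and the `Vec<u8>` buffer below are hypothetical stand-ins for sqlx's actual `Encode` machinery, and the sketch also hints at why the split is non-trivial: a move-only type that implements only the by-value trait would overlap with the blanket impl under Rust's coherence rules.

```rust
// Hypothetical split of encoding into a by-reference trait (what most
// types would implement) and a by-value trait (what the APIs accept).
trait EncodeByRef {
    fn encode_by_ref(&self, buf: &mut Vec<u8>);
}

trait EncodeByValue: Sized {
    fn encode(self, buf: &mut Vec<u8>);
}

// Blanket impl: anything encodable by reference is encodable by value.
impl<T: EncodeByRef> EncodeByValue for T {
    fn encode(self, buf: &mut Vec<u8>) {
        self.encode_by_ref(buf);
    }
}

// An ordinary type only needs the by-reference impl and gets the
// by-value method for free through the blanket impl.
struct Byte(u8);

impl EncodeByRef for Byte {
    fn encode_by_ref(&self, buf: &mut Vec<u8>) {
        buf.push(self.0);
    }
}
```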

Contributor Author:

@abonander I've made the change to Cell<Option<I>>
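The `Cell<Option<I>>` change can be illustrated with a std-only sketch. The struct name and the `Vec<i32>` buffer are stand-ins; the real impl encodes into a `PgArgumentBuffer` and returns a `BoxDynError`.

```rust
use std::cell::Cell;

// Wrapping the iterator in Cell<Option<I>> removes the `I: Clone` bound:
// a by-reference encode can move the iterator out with `Cell::take`,
// at the cost of the wrapper being !Sync and single-use.
struct BindIter<I>(Cell<Option<I>>);

impl<I: Iterator<Item = i32>> BindIter<I> {
    fn new(iter: I) -> Self {
        BindIter(Cell::new(Some(iter)))
    }

    // Analogous to `encode_by_ref`: only `&self` is available, but the
    // iterator is still consumed exactly once.
    fn encode_by_ref(&self, buf: &mut Vec<i32>) -> Result<(), &'static str> {
        let iter = self.0.take().ok_or("BindIter already consumed")?;
        buf.extend(iter);
        Ok(())
    }
}
```

A second call on the same wrapper fails cleanly instead of requiring `Clone`, matching the "only able to be used once" tradeoff described above.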

<I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
Self::encode_inner(self.0.clone(), buf)
}
fn encode(self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError>
where
Self: Sized,
{
Self::encode_inner(self.0, buf)
}
}
2 changes: 2 additions & 0 deletions sqlx-postgres/src/lib.rs
Expand Up @@ -7,6 +7,7 @@ use crate::executor::Executor;

mod advisory_lock;
mod arguments;
mod bind_iter;
mod column;
mod connection;
mod copy;
@@ -44,6 +45,7 @@ pub(crate) use sqlx_core::driver_prelude::*;

pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey};
pub use arguments::{PgArgumentBuffer, PgArguments};
pub use bind_iter::PgBindIterExt;
pub use column::PgColumn;
pub use connection::PgConnection;
pub use copy::{PgCopyIn, PgPoolCopyExt};
32 changes: 3 additions & 29 deletions sqlx-postgres/src/types/array.rs
@@ -5,7 +5,6 @@ use std::borrow::Cow;
 use crate::decode::Decode;
 use crate::encode::{Encode, IsNull};
 use crate::error::BoxDynError;
-use crate::type_info::PgType;
 use crate::types::Oid;
 use crate::types::Type;
 use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
@@ -156,39 +155,14 @@ where
     T: Encode<'q, Postgres> + Type<Postgres>,
 {
     fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
-        let type_info = self
-            .first()
-            .and_then(Encode::produces)
-            .unwrap_or_else(T::type_info);
-
-        buf.extend(&1_i32.to_be_bytes()); // number of dimensions
-        buf.extend(&0_i32.to_be_bytes()); // flags
-
-        // element type
-        match type_info.0 {
-            PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
-            PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
-
-            ty => {
-                buf.extend(&ty.oid().0.to_be_bytes());
-            }
-        }
-
-        let array_len = i32::try_from(self.len()).map_err(|_| {
+        // do the length check early to avoid doing unnecessary work
+        i32::try_from(self.len()).map_err(|_| {
             format!(
                 "encoded array length is too large for Postgres: {}",
                 self.len()
             )
         })?;
-
-        buf.extend(array_len.to_be_bytes()); // len
-        buf.extend(&1_i32.to_be_bytes()); // lower bound
-
-        for element in self.iter() {
-            buf.encode(element)?;
-        }
-
-        Ok(IsNull::No)
+        crate::PgBindIterExt::bind_iter(self.iter()).encode(buf)
     }
}
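The early length check that replaces the old inline encoding can be exercised on its own (the function name is illustrative; in the diff the check runs before delegating to `PgBindIter`):

```rust
/// Reject slices whose length cannot be represented as the i32 that
/// Postgres's binary array header requires, before any encoding work.
fn check_array_len(len: usize) -> Result<i32, String> {
    i32::try_from(len)
        .map_err(|_| format!("encoded array length is too large for Postgres: {len}"))
}
```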

58 changes: 58 additions & 0 deletions tests/postgres/postgres.rs
@@ -2042,3 +2042,61 @@ async fn test_issue_3052() {
"expected encode error, got {too_large_error:?}",
);
}

#[sqlx_macros::test]
async fn test_bind_iter() -> anyhow::Result<()> {
use sqlx::postgres::PgBindIterExt;
use sqlx::types::chrono::{DateTime, Utc};

let mut conn = new::<Postgres>().await?;

#[derive(sqlx::FromRow, PartialEq, Debug)]
struct Person {
id: i64,
name: String,
birthdate: DateTime<Utc>,
}

let people: Vec<Person> = vec![
Person {
id: 1,
name: "Alice".into(),
birthdate: "1984-01-01T00:00:00Z".parse().unwrap(),
},
Person {
id: 2,
name: "Bob".into(),
birthdate: "2000-01-01T00:00:00Z".parse().unwrap(),
},
];

sqlx::query(
r#"
create temporary table person(
id int8 primary key,
name text not null,
birthdate timestamptz not null
)"#,
)
.execute(&mut conn)
.await?;

let rows_affected =
sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
// owned value
.bind(people.iter().map(|p| p.id).bind_iter())
// borrowed value
.bind(people.iter().map(|p| &p.name).bind_iter())
.bind(people.iter().map(|p| &p.birthdate).bind_iter())
.execute(&mut conn)
.await?
.rows_affected();
assert_eq!(rows_affected, 2);

let p_query = sqlx::query_as::<_, Person>("select * from person order by id")
.fetch_all(&mut conn)
.await?;

assert_eq!(people, p_query);
Ok(())
}