diff --git a/sqlx-core/src/postgres/types/array.rs b/sqlx-core/src/postgres/types/array.rs index a478e11d05..55ebb7e520 100644 --- a/sqlx-core/src/postgres/types/array.rs +++ b/sqlx-core/src/postgres/types/array.rs @@ -69,11 +69,17 @@ where T: Encode<'q, Postgres> + Type, { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + let type_info = if self.len() < 1 { + T::type_info() + } else { + self[0].produces().unwrap_or_else(T::type_info) + }; + buf.extend(&1_i32.to_be_bytes()); // number of dimensions buf.extend(&0_i32.to_be_bytes()); // flags // element type - match T::type_info().0 { + match type_info.0 { PgType::DeclareWithName(name) => buf.patch_type_by_name(&name), ty => { diff --git a/sqlx-core/src/postgres/types/record.rs b/sqlx-core/src/postgres/types/record.rs index 8ae55c2c0a..1d9acab9b3 100644 --- a/sqlx-core/src/postgres/types/record.rs +++ b/sqlx-core/src/postgres/types/record.rs @@ -38,7 +38,7 @@ impl<'a> PgRecordEncoder<'a> { 'a: 'q, T: Encode<'q, Postgres> + Type, { - let ty = T::type_info(); + let ty = value.produces().unwrap_or_else(T::type_info); if let PgType::DeclareWithName(name) = ty.0 { // push a hole for this type ID diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 51dfbc6d37..c244a93dc1 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1201,6 +1201,128 @@ async fn it_can_copy_out() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_encodes_custom_array_issue_1504() -> anyhow::Result<()> { + use sqlx::encode::IsNull; + use sqlx::postgres::{PgArgumentBuffer, PgTypeInfo}; + use sqlx::{Decode, Encode, Type, ValueRef}; + + #[derive(Debug, PartialEq)] + enum Value { + String(String), + Number(i32), + Array(Vec), + } + + impl<'r> Decode<'r, Postgres> for Value { + fn decode( + value: sqlx::postgres::PgValueRef<'r>, + ) -> std::result::Result> { + let typ = value.type_info().into_owned(); + + if typ == PgTypeInfo::with_name("text") { + let s = >::decode(value)?; + + Ok(Self::String(s)) + } else if typ == PgTypeInfo::with_name("int4") { + let n = >::decode(value)?; + + Ok(Self::Number(n)) + } else if typ == PgTypeInfo::with_name("_text") { + let arr = Vec::::decode(value)?; + let v = arr.into_iter().map(|s| Value::String(s)).collect(); + + Ok(Self::Array(v)) + } else if typ == PgTypeInfo::with_name("_int4") { + let arr = Vec::::decode(value)?; + let v = arr.into_iter().map(|n| Value::Number(n)).collect(); + + Ok(Self::Array(v)) + } else { + Err("unknown type".into()) + } + } + } + + impl Encode<'_, Postgres> for Value { + fn produces(&self) -> Option { + match self { + Self::Array(a) => { + if a.len() < 1 { + return Some(PgTypeInfo::with_name("_text")); + } + + match a[0] { + Self::String(_) => Some(PgTypeInfo::with_name("_text")), + Self::Number(_) => Some(PgTypeInfo::with_name("_int4")), + Self::Array(_) => None, + } + } + Self::String(_) => Some(PgTypeInfo::with_name("text")), + Self::Number(_) => Some(PgTypeInfo::with_name("int4")), + } + } + + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + match self { + Value::String(s) => >::encode_by_ref(s, buf), + Value::Number(n) => >::encode_by_ref(n, buf), + Value::Array(arr) => arr.encode(buf), + } + } + } + + impl Type for Value { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("unknown") + } + + fn compatible(ty: &PgTypeInfo) -> bool { + [ + PgTypeInfo::with_name("text"), + PgTypeInfo::with_name("_text"), + PgTypeInfo::with_name("int4"), + PgTypeInfo::with_name("_int4"), + ] + .contains(ty) + } + } + + let mut conn = new::().await?; + + let (row,): (Value,) = sqlx::query_as("SELECT $1::text[] as Dummy") + .bind(Value::Array(vec![ + Value::String("Test 0".to_string()), + Value::String("Test 1".to_string()), + ])) + .fetch_one(&mut conn) + .await?; + + assert_eq!( + row, + Value::Array(vec![ + Value::String("Test 0".to_string()), + Value::String("Test 1".to_string()), + ]) + ); + + let (row,): (Value,) = sqlx::query_as("SELECT $1::int4[] as Dummy") + .bind(Value::Array(vec![ + Value::Number(3), + Value::Number(2), + Value::Number(1), + ])) + .fetch_one(&mut conn) + .await?; + + assert_eq!( + row, + Value::Array(vec![Value::Number(3), Value::Number(2), Value::Number(1)]) + ); + + Ok(()) +} + #[sqlx_macros::test] async fn test_issue_1254() -> anyhow::Result<()> { #[derive(sqlx::Type)]