From 32f1273565e1cd676e0514655fe8aec67b59b634 Mon Sep 17 00:00:00 2001 From: Charles Samborski Date: Wed, 29 Dec 2021 22:05:15 +0100 Subject: [PATCH] Fix support for Postgres array of custom types (#1483) This commit fixes the array decoder to support custom types. The core of the issue was that the array decoder did not use the type info retrieved from the database. It means that it only supported native types. This commit fixes the issue by using the element type info fetched from the database. A new internal helper method is added to the `PgType` struct: it returns the type info for the inner array element, if available. Closes #1477 --- sqlx-core/src/postgres/connection/describe.rs | 6 +- sqlx-core/src/postgres/type_info.rs | 120 ++++++++++++++++++ sqlx-core/src/postgres/types/array.rs | 7 +- tests/postgres/postgres.rs | 87 +++++++++++++ 4 files changed, 216 insertions(+), 4 deletions(-) diff --git a/sqlx-core/src/postgres/connection/describe.rs b/sqlx-core/src/postgres/connection/describe.rs index 6d7a3f7dc0..3b8d936fb3 100644 --- a/sqlx-core/src/postgres/connection/describe.rs +++ b/sqlx-core/src/postgres/connection/describe.rs @@ -16,6 +16,7 @@ use std::sync::Arc; /// Describes the type of the `pg_type.typtype` column /// /// See +#[derive(Copy, Clone, Debug, Eq, PartialEq)] enum TypType { Base, Composite, @@ -45,6 +46,7 @@ impl TryFrom for TypType { /// Describes the type of the `pg_type.typcategory` column /// /// See +#[derive(Copy, Clone, Debug, Eq, PartialEq)] enum TypCategory { Array, Boolean, @@ -198,7 +200,9 @@ impl PgConnection { (Ok(TypType::Base), Ok(TypCategory::Array)) => { Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { - kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?), + kind: PgTypeKind::Array( + self.maybe_fetch_type_info_by_oid(element, true).await?, + ), name: name.into(), oid, })))) diff --git a/sqlx-core/src/postgres/type_info.rs b/sqlx-core/src/postgres/type_info.rs index 37c018f798..97d5efa0cc 100644 --- a/sqlx-core/src/postgres/type_info.rs +++ b/sqlx-core/src/postgres/type_info.rs @@ -1,5 +1,6 @@ #![allow(dead_code)] +use std::borrow::Cow; use std::fmt::{self, Display, Formatter}; use std::ops::Deref; use std::sync::Arc; @@ -750,6 +751,125 @@ impl PgType { } } } + + /// If `self` is an array type, return the type info for its element. + /// + /// This method should only be called on resolved types: calling it on + /// a type that is merely declared (DeclareWithOid/Name) is a bug. + pub(crate) fn try_array_element(&self) -> Option> { + // We explicitly match on all the `None` cases to ensure an exhaustive match. + match self { + PgType::Bool => None, + PgType::BoolArray => Some(Cow::Owned(PgTypeInfo(PgType::Bool))), + PgType::Bytea => None, + PgType::ByteaArray => Some(Cow::Owned(PgTypeInfo(PgType::Bytea))), + PgType::Char => None, + PgType::CharArray => Some(Cow::Owned(PgTypeInfo(PgType::Char))), + PgType::Name => None, + PgType::NameArray => Some(Cow::Owned(PgTypeInfo(PgType::Name))), + PgType::Int8 => None, + PgType::Int8Array => Some(Cow::Owned(PgTypeInfo(PgType::Int8))), + PgType::Int2 => None, + PgType::Int2Array => Some(Cow::Owned(PgTypeInfo(PgType::Int2))), + PgType::Int4 => None, + PgType::Int4Array => Some(Cow::Owned(PgTypeInfo(PgType::Int4))), + PgType::Text => None, + PgType::TextArray => Some(Cow::Owned(PgTypeInfo(PgType::Text))), + PgType::Oid => None, + PgType::OidArray => Some(Cow::Owned(PgTypeInfo(PgType::Oid))), + PgType::Json => None, + PgType::JsonArray => Some(Cow::Owned(PgTypeInfo(PgType::Json))), + PgType::Point => None, + PgType::PointArray => Some(Cow::Owned(PgTypeInfo(PgType::Point))), + PgType::Lseg => None, + PgType::LsegArray => Some(Cow::Owned(PgTypeInfo(PgType::Lseg))), + PgType::Path => None, + PgType::PathArray => Some(Cow::Owned(PgTypeInfo(PgType::Path))), + PgType::Box => None, + PgType::BoxArray => Some(Cow::Owned(PgTypeInfo(PgType::Box))), + PgType::Polygon => None, + PgType::PolygonArray => Some(Cow::Owned(PgTypeInfo(PgType::Polygon))), + PgType::Line => None, + PgType::LineArray => Some(Cow::Owned(PgTypeInfo(PgType::Line))), + PgType::Cidr => None, + PgType::CidrArray => Some(Cow::Owned(PgTypeInfo(PgType::Cidr))), + PgType::Float4 => None, + PgType::Float4Array => Some(Cow::Owned(PgTypeInfo(PgType::Float4))), + PgType::Float8 => None, + PgType::Float8Array => Some(Cow::Owned(PgTypeInfo(PgType::Float8))), + PgType::Circle => None, + PgType::CircleArray => Some(Cow::Owned(PgTypeInfo(PgType::Circle))), + PgType::Macaddr8 => None, + PgType::Macaddr8Array => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr8))), + PgType::Money => None, + PgType::MoneyArray => Some(Cow::Owned(PgTypeInfo(PgType::Money))), + PgType::Macaddr => None, + PgType::MacaddrArray => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr))), + PgType::Inet => None, + PgType::InetArray => Some(Cow::Owned(PgTypeInfo(PgType::Inet))), + PgType::Bpchar => None, + PgType::BpcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Bpchar))), + PgType::Varchar => None, + PgType::VarcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Varchar))), + PgType::Date => None, + PgType::DateArray => Some(Cow::Owned(PgTypeInfo(PgType::Date))), + PgType::Time => None, + PgType::TimeArray => Some(Cow::Owned(PgTypeInfo(PgType::Time))), + PgType::Timestamp => None, + PgType::TimestampArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamp))), + PgType::Timestamptz => None, + PgType::TimestamptzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamptz))), + PgType::Interval => None, + PgType::IntervalArray => Some(Cow::Owned(PgTypeInfo(PgType::Interval))), + PgType::Timetz => None, + PgType::TimetzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timetz))), + PgType::Bit => None, + PgType::BitArray => Some(Cow::Owned(PgTypeInfo(PgType::Bit))), + PgType::Varbit => None, + PgType::VarbitArray => Some(Cow::Owned(PgTypeInfo(PgType::Varbit))), + PgType::Numeric => None, + PgType::NumericArray => Some(Cow::Owned(PgTypeInfo(PgType::Numeric))), + PgType::Record => None, + PgType::RecordArray => Some(Cow::Owned(PgTypeInfo(PgType::Record))), + PgType::Uuid => None, + PgType::UuidArray => Some(Cow::Owned(PgTypeInfo(PgType::Uuid))), + PgType::Jsonb => None, + PgType::JsonbArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonb))), + PgType::Int4Range => None, + PgType::Int4RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int4Range))), + PgType::NumRange => None, + PgType::NumRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::NumRange))), + PgType::TsRange => None, + PgType::TsRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TsRange))), + PgType::TstzRange => None, + PgType::TstzRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TstzRange))), + PgType::DateRange => None, + PgType::DateRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::DateRange))), + PgType::Int8Range => None, + PgType::Int8RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int8Range))), + PgType::Jsonpath => None, + PgType::JsonpathArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonpath))), + // There is no `UnknownArray` + PgType::Unknown => None, + // There is no `VoidArray` + PgType::Void => None, + PgType::Custom(ty) => match &ty.kind { + PgTypeKind::Simple => None, + PgTypeKind::Pseudo => None, + PgTypeKind::Domain(_) => None, + PgTypeKind::Composite(_) => None, + PgTypeKind::Array(ref elem_type_info) => Some(Cow::Borrowed(elem_type_info)), + PgTypeKind::Enum(_) => None, + PgTypeKind::Range(_) => None, + }, + PgType::DeclareWithOid(oid) => { + unreachable!("(bug) use of unresolved type declaration [oid={}]", oid); + } + PgType::DeclareWithName(name) => { + unreachable!("(bug) use of unresolved type declaration [name={}]", name); + } + } + } } impl TypeInfo for PgTypeInfo { diff --git a/sqlx-core/src/postgres/types/array.rs b/sqlx-core/src/postgres/types/array.rs index 55ebb7e520..68a0ac385e 100644 --- a/sqlx-core/src/postgres/types/array.rs +++ b/sqlx-core/src/postgres/types/array.rs @@ -1,4 +1,5 @@ use bytes::Buf; +use std::borrow::Cow; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; @@ -103,7 +104,6 @@ where T: for<'a> Decode<'a, Postgres> + Type, { fn decode(value: PgValueRef<'r>) -> Result { - let element_type_info; let format = value.format(); match format { @@ -131,7 +131,8 @@ where // the OID of the element let element_type_oid = buf.get_u32(); - element_type_info = PgTypeInfo::try_from_oid(element_type_oid) + let element_type_info: PgTypeInfo = PgTypeInfo::try_from_oid(element_type_oid) + .or_else(|| value.type_info.try_array_element().map(Cow::into_owned)) .unwrap_or_else(|| PgTypeInfo(PgType::DeclareWithOid(element_type_oid))); // length of the array axis @@ -159,7 +160,7 @@ where PgValueFormat::Text => { // no type is provided from the database for the element - element_type_info = T::type_info(); + let element_type_info = T::type_info(); let s = value.as_str()?; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index c244a93dc1..ac755018bc 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1094,6 +1094,93 @@ CREATE TABLE heating_bills ( Ok(()) } +#[sqlx_macros::test] +async fn it_resolves_custom_type_in_array() -> anyhow::Result<()> { + // Only supported in Postgres 11+ + let mut conn = new::().await?; + if matches!(conn.server_version_num(), Some(version) if version < 110000) { + return Ok(()); + } + + // language=PostgreSQL + conn.execute( + r#" +DROP TABLE IF EXISTS pets; +DROP TYPE IF EXISTS pet_name_and_race; + +CREATE TYPE pet_name_and_race AS ( + name TEXT, + race TEXT +); +CREATE TABLE pets ( + owner TEXT NOT NULL, + name TEXT NOT NULL, + race TEXT NOT NULL, + PRIMARY KEY (owner, name) +); +INSERT INTO pets(owner, name, race) +VALUES + ('Alice', 'Foo', 'cat'); +INSERT INTO pets(owner, name, race) +VALUES + ('Alice', 'Bar', 'dog'); + "#, + ) + .await?; + + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct PetNameAndRace { + name: String, + race: String, + } + + impl sqlx::Type for PetNameAndRace { + fn type_info() -> sqlx::postgres::PgTypeInfo { + sqlx::postgres::PgTypeInfo::with_name("pet_name_and_race") + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRace { + fn decode( + value: sqlx::postgres::PgValueRef<'r>, + ) -> Result> { + let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?; + let name = decoder.try_decode::()?; + let race = decoder.try_decode::()?; + Ok(Self { name, race }) + } + } + + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct PetNameAndRaceArray(Vec); + + impl sqlx::Type for PetNameAndRaceArray { + fn type_info() -> sqlx::postgres::PgTypeInfo { + // Array type name is the name of the element type prefixed with `_` + sqlx::postgres::PgTypeInfo::with_name("_pet_name_and_race") + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRaceArray { + fn decode( + value: sqlx::postgres::PgValueRef<'r>, + ) -> Result> { + Ok(Self(Vec::::decode(value)?)) + } + } + + let mut conn = new::().await?; + + let row = sqlx::query("select owner, array_agg(row(name, race)::pet_name_and_race) as pets from pets group by owner") + .fetch_one(&mut conn) + .await?; + + let pets: PetNameAndRaceArray = row.get("pets"); + + assert_eq!(pets.0.len(), 2); + Ok(()) +} + #[sqlx_macros::test] async fn test_pg_server_num() -> anyhow::Result<()> { use sqlx::postgres::PgConnectionInfo;