Skip to content

Commit

Permalink
feat: implement Encode,Decode,Type for Arc<str> and Arc<[u8]>
Browse files Browse the repository at this point in the history
  • Loading branch information
joeydewaal committed Jan 10, 2025
1 parent 6b33766 commit 5bf1fdd
Show file tree
Hide file tree
Showing 10 changed files with 312 additions and 0 deletions.
24 changes: 24 additions & 0 deletions sqlx-mysql/src/types/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
Expand Down Expand Up @@ -83,3 +85,25 @@ impl Decode<'_, MySql> for Vec<u8> {
<&[u8] as Decode<MySql>>::decode(value).map(ToOwned::to_owned)
}
}

impl Type<MySql> for Arc<[u8]> {
fn type_info() -> MySqlTypeInfo {
<[u8] as Type<MySql>>::type_info()
}

fn compatible(ty: &MySqlTypeInfo) -> bool {
<&[u8] as Type<MySql>>::compatible(ty)
}
}

impl Encode<'_, MySql> for Arc<[u8]> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
<&[u8] as Encode<MySql>>::encode(&**self, buf)
}
}

impl Decode<'_, MySql> for Arc<[u8]> {
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
<&[u8] as Decode<MySql>>::decode(value).map(Into::into)
}
}
23 changes: 23 additions & 0 deletions sqlx-mysql/src/types/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::protocol::text::{ColumnFlags, ColumnType};
use crate::types::Type;
use crate::{MySql, MySqlTypeInfo, MySqlValueRef};
use std::borrow::Cow;
use std::sync::Arc;

impl Type<MySql> for str {
fn type_info() -> MySqlTypeInfo {
Expand Down Expand Up @@ -114,3 +115,25 @@ impl<'r> Decode<'r, MySql> for Cow<'r, str> {
value.as_str().map(Cow::Borrowed)
}
}

impl Type<MySql> for Arc<str> {
fn type_info() -> MySqlTypeInfo {
<str as Type<MySql>>::type_info()
}

fn compatible(ty: &MySqlTypeInfo) -> bool {
<str as Type<MySql>>::compatible(ty)
}
}

impl Encode<'_, MySql> for Arc<str> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
<&str as Encode<MySql>>::encode(&**self, buf)
}
}

impl Decode<'_, MySql> for Arc<str> {
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
<&str as Decode<MySql>>::decode(value).map(Into::into)
}
}
34 changes: 34 additions & 0 deletions sqlx-postgres/src/types/array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use sqlx_core::bytes::Buf;
use sqlx_core::types::Text;
use std::borrow::Cow;
use std::sync::Arc;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
Expand Down Expand Up @@ -130,6 +131,19 @@ where
}
}

impl<T> Type<Postgres> for Arc<[T]>
where
T: PgHasArrayType,
{
fn type_info() -> PgTypeInfo {
T::array_type_info()
}

fn compatible(ty: &PgTypeInfo) -> bool {
T::array_compatible(ty)
}
}

impl<'q, T> Encode<'q, Postgres> for Vec<T>
where
for<'a> &'a [T]: Encode<'q, Postgres>,
Expand Down Expand Up @@ -192,6 +206,17 @@ where
}
}

impl<'q, T> Encode<'q, Postgres> for Arc<[T]>
where
for<'a> &'a [T]: Encode<'q, Postgres>,
T: Encode<'q, Postgres>,
{
#[inline]
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
<&[T] as Encode<Postgres>>::encode_by_ref(&self.as_ref(), buf)
}
}

impl<'r, T, const N: usize> Decode<'r, Postgres> for [T; N]
where
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
Expand Down Expand Up @@ -354,3 +379,12 @@ where
}
}
}

impl<'r, T> Decode<'r, Postgres> for Arc<[T]>
where
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
<Vec<T> as Decode<Postgres>>::decode(value).map(Into::into)
}
}
23 changes: 23 additions & 0 deletions sqlx-postgres/src/types/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
Expand Down Expand Up @@ -28,6 +30,12 @@ impl PgHasArrayType for Vec<u8> {
}
}

impl PgHasArrayType for Arc<[u8]> {
fn array_type_info() -> PgTypeInfo {
<[&[u8]] as Type<Postgres>>::type_info()
}
}

impl<const N: usize> PgHasArrayType for [u8; N] {
fn array_type_info() -> PgTypeInfo {
<[&[u8]] as Type<Postgres>>::type_info()
Expand Down Expand Up @@ -60,6 +68,12 @@ impl<const N: usize> Encode<'_, Postgres> for [u8; N] {
}
}

impl Encode<'_, Postgres> for Arc<[u8]> {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
<&[u8] as Encode<Postgres>>::encode(self, buf)
}
}

impl<'r> Decode<'r, Postgres> for &'r [u8] {
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
Expand Down Expand Up @@ -110,3 +124,12 @@ impl<const N: usize> Decode<'_, Postgres> for [u8; N] {
Ok(bytes)
}
}

impl Decode<'_, Postgres> for Arc<[u8]> {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => value.as_bytes()?.into(),
PgValueFormat::Text => hex::decode(text_hex_decode_input(value)?)?.into(),
})
}
}
33 changes: 33 additions & 0 deletions sqlx-postgres/src/types/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::types::array_compatible;
use crate::types::Type;
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres};
use std::borrow::Cow;
use std::sync::Arc;

impl Type<Postgres> for str {
fn type_info() -> PgTypeInfo {
Expand Down Expand Up @@ -54,6 +55,16 @@ impl Type<Postgres> for String {
}
}

impl Type<Postgres> for Arc<str> {
fn type_info() -> PgTypeInfo {
<&str as Type<Postgres>>::type_info()
}

fn compatible(ty: &PgTypeInfo) -> bool {
<&str as Type<Postgres>>::compatible(ty)
}
}

impl PgHasArrayType for &'_ str {
fn array_type_info() -> PgTypeInfo {
PgTypeInfo::TEXT_ARRAY
Expand Down Expand Up @@ -94,6 +105,16 @@ impl PgHasArrayType for String {
}
}

impl PgHasArrayType for Arc<str> {
fn array_type_info() -> PgTypeInfo {
<&str as PgHasArrayType>::array_type_info()
}

fn array_compatible(ty: &PgTypeInfo) -> bool {
<&str as PgHasArrayType>::array_compatible(ty)
}
}

impl Encode<'_, Postgres> for &'_ str {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
buf.extend(self.as_bytes());
Expand Down Expand Up @@ -123,6 +144,12 @@ impl Encode<'_, Postgres> for String {
}
}

impl Encode<'_, Postgres> for Arc<str> {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
<&str as Encode<Postgres>>::encode(&**self, buf)
}
}

impl<'r> Decode<'r, Postgres> for &'r str {
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
value.as_str()
Expand All @@ -146,3 +173,9 @@ impl Decode<'_, Postgres> for String {
Ok(value.as_str()?.to_owned())
}
}

impl Decode<'_, Postgres> for Arc<str> {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(value.as_str()?.into())
}
}
34 changes: 34 additions & 0 deletions sqlx-sqlite/src/types/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::sync::Arc;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
Expand Down Expand Up @@ -101,3 +102,36 @@ impl<'r> Decode<'r, Sqlite> for Vec<u8> {
Ok(value.blob().to_owned())
}
}

impl Type<Sqlite> for Arc<[u8]> {
fn type_info() -> SqliteTypeInfo {
<&[u8] as Type<Sqlite>>::type_info()
}

fn compatible(ty: &SqliteTypeInfo) -> bool {
<&[u8] as Type<Sqlite>>::compatible(ty)
}
}

impl<'q> Encode<'q, Sqlite> for Arc<[u8]> {
fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Blob(Cow::Owned(self.to_vec())));

Ok(IsNull::No)
}

fn encode_by_ref(
&self,
args: &mut Vec<SqliteArgumentValue<'q>>,
) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Blob(Cow::Owned(self.to_vec())));

Ok(IsNull::No)
}
}

impl<'r> Decode<'r, Sqlite> for Arc<[u8]> {
fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(value.blob().into())
}
}
30 changes: 30 additions & 0 deletions sqlx-sqlite/src/types/str.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::sync::Arc;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
Expand Down Expand Up @@ -122,3 +123,32 @@ impl<'r> Decode<'r, Sqlite> for Cow<'r, str> {
value.text().map(Cow::Borrowed)
}
}

impl Type<Sqlite> for Arc<str> {
fn type_info() -> SqliteTypeInfo {
<&str as Type<Sqlite>>::type_info()
}
}

impl<'q> Encode<'q, Sqlite> for Arc<str> {
fn encode(self, args: &mut Vec<SqliteArgumentValue<'q>>) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Text(Cow::Owned(self.to_string())));

Ok(IsNull::No)
}

fn encode_by_ref(
&self,
args: &mut Vec<SqliteArgumentValue<'q>>,
) -> Result<IsNull, BoxDynError> {
args.push(SqliteArgumentValue::Text(Cow::Owned(self.to_string())));

Ok(IsNull::No)
}
}

impl<'r> Decode<'r, Sqlite> for Arc<str> {
fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
value.text().map(Into::into)
}
}
31 changes: 31 additions & 0 deletions tests/mysql/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ extern crate time_ as time;
use std::net::SocketAddr;
#[cfg(feature = "rust_decimal")]
use std::str::FromStr;
use std::sync::Arc;

use sqlx::mysql::MySql;
use sqlx::{Executor, Row};
Expand Down Expand Up @@ -384,3 +385,33 @@ CREATE TEMPORARY TABLE user_login (

Ok(())
}

#[sqlx_macros::test]
async fn test_arc_str() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;

let name: Arc<str> = "Harold".into();

let username: Arc<str> = sqlx::query_scalar("SELECT ? AS username")
.bind(&name)
.fetch_one(&mut conn)
.await?;

assert!(username == name);
Ok(())
}

#[sqlx_macros::test]
async fn test_arc_slice() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;

let name: Arc<[u8]> = [5, 0].into();

let username: Arc<[u8]> = sqlx::query_scalar("SELECT ?")
.bind(&name)
.fetch_one(&mut conn)
.await?;

assert!(username == name);
Ok(())
}
Loading

0 comments on commit 5bf1fdd

Please sign in to comment.