diff --git a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs index 517e7f1863c7..de2b04ba34d3 100644 --- a/datafusion/core/tests/sql/functions.rs +++ b/datafusion/core/tests/sql/functions.rs @@ -610,3 +610,61 @@ async fn pi_function() -> Result<()> { assert_batches_eq!(expected, &actual); Ok(()) } + +macro_rules! assert_logical_plan { + ($ctx:expr, $udf_call:expr, $expected:expr) => {{ + let sql = format!("SELECT {}", $udf_call); + let logical_plan = $ctx.create_logical_plan(&sql)?; + let formatted = format!("{:?}", logical_plan); + assert_eq!($expected, formatted.split("\n").collect::>()); + }}; +} + +#[tokio::test] +async fn test_log_round_logical_plan() -> Result<()> { + let ctx = SessionContext::new(); + + assert_logical_plan!( + ctx, + "log(2.0, 2)", + vec!["Projection: log(Float64(2), Int64(2))", " EmptyRelation",] + ); + + assert_logical_plan!( + ctx, + "log(2::Decimal(38,10), 2::Decimal(38,10))", + vec![ + "Projection: log(CAST(Int64(2) AS Decimal(38, 10)), CAST(Int64(2) AS Decimal(38, 10)))", + " EmptyRelation", + ] + ); + + assert_logical_plan!( + ctx, + "log(2::Decimal(38,10), 2)", + vec![ + "Projection: log(CAST(Int64(2) AS Decimal(38, 10)), Int64(2))", + " EmptyRelation", + ] + ); + + assert_logical_plan!( + ctx, + "round(5.7, 2)", + vec![ + "Projection: round(Float64(5.7), Int64(2))", + " EmptyRelation", + ] + ); + + assert_logical_plan!( + ctx, + "round(5.7::Decimal(38,10), 2)", + vec![ + "Projection: round(CAST(Float64(5.7) AS Decimal(38, 10)), Int64(2))", + " EmptyRelation", + ] + ); + + Ok(()) +} diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index e6cdfa428f7b..a3dad8f70dfc 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -559,6 +559,8 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { DataType::Decimal(38, 10), DataType::Decimal(38, 10), ]), + TypeSignature::Exact(vec![DataType::Decimal(38, 10), DataType::Int64]), + TypeSignature::Exact(vec![DataType::Float64, DataType::Int64]), ], fun.volatility(), ), @@ -567,6 +569,8 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { TypeSignature::Exact(vec![DataType::Float64]), TypeSignature::Exact(vec![DataType::Float32]), // NOTE: stub, won't execute + TypeSignature::Exact(vec![DataType::Float64, DataType::Int64]), + TypeSignature::Exact(vec![DataType::Decimal(38, 10), DataType::Int64]), TypeSignature::Exact(vec![DataType::Decimal(38, 10), DataType::Int32]), ], fun.volatility(),