[SPARK-4536][SQL] Add sqrt and abs to Spark SQL DSL
Spark SQL has built-in sqrt and abs functions, but the DSL doesn't support them.

Author: Kousuke Saruta <[email protected]>

Closes apache#3401 from sarutak/dsl-missing-operator and squashes the following commits:

07700cf [Kousuke Saruta] Modified Literal(null, NullType) to Literal(null) in DslQuerySuite
8f366f8 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into dsl-missing-operator
1b88e2e [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into dsl-missing-operator
0396f89 [Kousuke Saruta] Added sqrt and abs to Spark SQL DSL
sarutak authored and marmbrus committed Dec 2, 2014
1 parent b1f8fe3 commit e75e04f
Showing 4 changed files with 74 additions and 1 deletion.
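
For orientation, here is a minimal usage sketch (not part of the commit) of how the new DSL functions can be called once this change is in. It reuses testData and negativeData from the test sources below; the imports are an assumption about how the DSL conversions reach user code, mirroring what DslQuerySuite relies on.

// Usage sketch only. Assumes TestSQLContext._ exposes the DSL conversions
// (sqrt, abs, Symbol-to-attribute) the same way DslQuerySuite uses them, and
// that TestData provides the testData/negativeData SchemaRDDs added below.
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.TestData._

// Square root of each integer key, smallest key first: 1.0, 1.414..., 1.732..., ...
val roots = testData.select(sqrt('key)).orderBy('key asc)

// Absolute value of the negated keys; ordering by key descending yields 1, 2, ..., 100.
val magnitudes = negativeData.select(abs('key)).orderBy('key desc)

roots.collect().foreach(println)
magnitudes.collect().foreach(println)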
@@ -147,6 +147,8 @@ package object dsl {
    def max(e: Expression) = Max(e)
    def upper(e: Expression) = Upper(e)
    def lower(e: Expression) = Lower(e)
    def sqrt(e: Expression) = Sqrt(e)
    def abs(e: Expression) = Abs(e)

    implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
    // TODO more implicit class for literal?
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.types._
import scala.math.pow

case class UnaryMinus(child: Expression) extends UnaryExpression {
  type EvaluatedType = Any
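
The DSL helpers above are thin wrappers over the existing Catalyst expression nodes (Sqrt and Abs in org.apache.spark.sql.catalyst.expressions). As a rough illustration only, and assuming these nodes can be evaluated directly with a null input row because their children are literals, the same behavior can be exercised without running a query:

import org.apache.spark.sql.catalyst.expressions.{Abs, Literal, Sqrt}

// Illustration only: eval(null) is assumed to be valid here because the
// children are literals and no input row is consulted.
println(Sqrt(Literal(9.0)).eval(null))  // expected: 3.0
println(Abs(Literal(-5)).eval(null))    // expected: 5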
68 changes: 68 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -282,4 +282,72 @@ class DslQuerySuite extends QueryTest {
(1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil
)
}

test("sqrt") {
checkAnswer(
testData.select(sqrt('key)).orderBy('key asc),
(1 to 100).map(n => Seq(math.sqrt(n)))
)

checkAnswer(
testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc),
(1 to 100).map(n => Seq(math.sqrt(n), n))
)

checkAnswer(
testData.select(sqrt(Literal(null))),
(1 to 100).map(_ => Seq(null))
)
}

test("abs") {
checkAnswer(
testData.select(abs('key)).orderBy('key asc),
(1 to 100).map(n => Seq(n))
)

checkAnswer(
negativeData.select(abs('key)).orderBy('key desc),
(1 to 100).map(n => Seq(n))
)

checkAnswer(
testData.select(abs(Literal(null))),
(1 to 100).map(_ => Seq(null))
)
}

test("upper") {
checkAnswer(
lowerCaseData.select(upper('l)),
('a' to 'd').map(c => Seq(c.toString.toUpperCase()))
)

checkAnswer(
testData.select(upper('value), 'key),
(1 to 100).map(n => Seq(n.toString, n))
)

checkAnswer(
testData.select(upper(Literal(null))),
(1 to 100).map(n => Seq(null))
)
}

test("lower") {
checkAnswer(
upperCaseData.select(lower('L)),
('A' to 'F').map(c => Seq(c.toString.toLowerCase()))
)

checkAnswer(
testData.select(lower('value), 'key),
(1 to 100).map(n => Seq(n.toString, n))
)

checkAnswer(
testData.select(lower(Literal(null))),
(1 to 100).map(n => Seq(null))
)
}
}
4 changes: 4 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -32,6 +32,10 @@ object TestData {
    (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD
  testData.registerTempTable("testData")

  val negativeData = TestSQLContext.sparkContext.parallelize(
    (1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD
  negativeData.registerTempTable("negativeData")

  case class LargeAndSmallInts(a: Int, b: Int)
  val largeAndSmallInts =
    TestSQLContext.sparkContext.parallelize(
