Skip to content

Commit

Permalink
add more unit tests and window functions
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed May 21, 2015
1 parent 64e18a7 commit 964c013
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 8 deletions.
16 changes: 15 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
/**
 * Computes the bitwise XOR of this expression and another expression.
 *
 * @group expr_ops
 */
def bitwiseXOR(other: Any): Column = {
  val rhs = lit(other).expr
  BitwiseXor(expr, rhs)
}

/**
 * Creates a new [[WindowFunctionDefinition]] bundled with this column.
 * {{{
 *   df.select(avg($"value").over...)
 * }}}
 *
 * @group expr_ops
 */
def over: WindowFunctionDefinition = new WindowFunctionDefinition(this)

/**
 * Reuses an existing [[WindowFunctionDefinition]], binding it to this column.
 * This lets several expressions share a single partition/order specification:
 * {{{
 *   val w = over.partitionBy("name").orderBy("id")
 *   df.select(
 *     sum("price").over(w).between.preceding(2),
 *     avg("price").over(w).between.preceding(4)
 *   )
 * }}}
 *
 * @param w the window specification to bind this column to
 * @group expr_ops
 */
// Explicit result type added: public API members should not rely on inference,
// and the inferred type here (from newColumn) is WindowFunctionDefinition.
def over(w: WindowFunctionDefinition): WindowFunctionDefinition = w.newColumn(this)

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,15 @@ import org.apache.spark.sql.catalyst.expressions._
*/
@Experimental
class WindowFunctionDefinition protected[sql](
column: Column,
column: Column = null,
partitionSpec: Seq[Expression] = Nil,
orderSpec: Seq[SortOrder] = Nil,
frame: WindowFrame = UnspecifiedFrame) {

/** Returns a copy of this window specification bound to column `c`. */
private[sql] def newColumn(c: Column): WindowFunctionDefinition =
  new WindowFunctionDefinition(c, partitionSpec, orderSpec, frame)

/**
* Returns a new [[WindowFunctionDefinition]] partitioned by the specified column.
* {{{
Expand Down Expand Up @@ -218,6 +222,9 @@ class WindowFunctionDefinition protected[sql](
* @group window_funcs
*/
def toColumn: Column = {
if (column == null) {
throw new AnalysisException("Window didn't bind with expression")
}
val windowExpr = column.expr match {
case Average(child) => WindowExpression(
UnresolvedWindowFunction("avg", child :: Nil),
Expand Down
128 changes: 128 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,132 @@ object functions {
*/
def max(columnName: String): Column = max(Column(columnName))

//////////////////////////////////////////////////////////////////////////////////////////////
// Window functions
//////////////////////////////////////////////////////////////////////////////////////////////

/**
 * Window function: returns the value of the named column one row before the
 * current row, or null when no such row exists within the window partition.
 *
 * @group window_funcs
 */
def lag(columnName: String): Column = lag(columnName, 1)

/**
 * Window function: returns the value of `e` one row before the current row,
 * or null when no such row exists within the window partition.
 *
 * @group window_funcs
 */
def lag(e: Column): Column = lag(e, 1)

/**
 * Window function: returns the value of `e` `count` rows before the current
 * row, or null when the offset reaches before the beginning of the window.
 *
 * @group window_funcs
 */
def lag(e: Column, count: Int): Column = lag(e, count, null)

/**
 * Window function: returns the value of the named column `count` rows before
 * the current row, or null when the offset reaches before the beginning of
 * the window.
 *
 * @group window_funcs
 */
def lag(columnName: String, count: Int): Column = lag(columnName, count, null)

/**
 * Window function: returns the value of the named column `count` rows before
 * the current row, or `defaultValue` when the offset reaches before the
 * beginning of the window.
 *
 * @group window_funcs
 */
def lag(columnName: String, count: Int, defaultValue: Any): Column =
  lag(Column(columnName), count, defaultValue)

/**
 * Window function: returns the value of `e` evaluated `count` rows before the
 * current row, or `defaultValue` when the offset reaches before the beginning
 * of the window.
 *
 * Builds an unresolved "lag" window function call to be resolved by the
 * analyzer (see `WindowFunctionDefinition.toColumn`).
 * NOTE(review): the declared result is `Column` but the body produces an
 * `UnresolvedWindowFunction` expression — presumably an implicit
 * Expression-to-Column conversion is in scope; confirm.
 *
 * @group window_funcs
 */
def lag(e: Column, count: Int, defaultValue: Any): Column = {
UnresolvedWindowFunction("lag", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil)
}

/**
 * Window function: returns the value of the named column one row after the
 * current row, or null when no such row exists within the window partition.
 *
 * @group window_funcs
 */
def lead(columnName: String): Column = lead(columnName, 1)

/**
 * Window function: returns the value of `e` one row after the current row,
 * or null when no such row exists within the window partition.
 *
 * @group window_funcs
 */
def lead(e: Column): Column = lead(e, 1)

/**
 * Window function: returns the value of the named column `count` rows after
 * the current row, or null when the offset reaches beyond the end of the
 * window.
 *
 * @group window_funcs
 */
def lead(columnName: String, count: Int): Column = lead(columnName, count, null)

/**
 * Window function: returns the value of `e` `count` rows after the current
 * row, or null when the offset reaches beyond the end of the window.
 *
 * @group window_funcs
 */
def lead(e: Column, count: Int): Column = lead(e, count, null)

/**
 * Window function: returns the value of the named column `count` rows after
 * the current row, or `defaultValue` when the offset reaches beyond the end
 * of the window.
 *
 * @group window_funcs
 */
def lead(columnName: String, count: Int, defaultValue: Any): Column =
  lead(Column(columnName), count, defaultValue)

/**
 * Window function: returns the value of `e` evaluated `count` rows after the
 * current row, or `defaultValue` when the offset reaches beyond the end of
 * the window.
 *
 * Builds an unresolved "lead" window function call to be resolved by the
 * analyzer (see `WindowFunctionDefinition.toColumn`).
 * NOTE(review): the declared result is `Column` but the body produces an
 * `UnresolvedWindowFunction` expression — presumably an implicit
 * Expression-to-Column conversion is in scope; confirm.
 *
 * @group window_funcs
 */
def lead(e: Column, count: Int, defaultValue: Any): Column = {
UnresolvedWindowFunction("lead", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil)
}

//////////////////////////////////////////////////////////////////////////////////////////////
// Non-aggregate functions
//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1393,4 +1519,6 @@ object functions {
UnresolvedFunction(udfName, cols.map(_.expr))
}

/**
 * Creates an empty [[WindowFunctionDefinition]], to be configured with
 * `partitionBy`/`orderBy` and later bound to an expression via
 * `Column.over(w)`. Calling `toColumn` on an unbound definition throws
 * an AnalysisException.
 *
 * @group window_funcs
 */
def over: WindowFunctionDefinition = new WindowFunctionDefinition()

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,67 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._

class HiveDataFrameWindowSuite extends QueryTest {

test("reuse window") {
  // One window specification shared by two lead expressions.
  val data = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
  val window = over.partitionBy("key").orderBy("value")
  val projected = data.select(
    lead("key").over(window).toColumn,
    lead("value").over(window).toColumn)
  checkAnswer(
    projected,
    Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
}

test("lead in window") {
  val data = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
  // lead with default offset 1: the last row of each partition has no successor.
  val leadCol = lead("value").over
    .partitionBy($"key")
    .orderBy($"value")
    .toColumn
  checkAnswer(
    data.select(leadCol),
    Row("1") :: Row("2") :: Row(null) :: Row(null) :: Nil)
}

test("lag in window") {
  val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
  checkAnswer(
    df.select(
      // Fixed: this test previously called lead(), so lag() was never
      // exercised. For this symmetric data the expected multiset of results
      // (one null and one value per two-row partition) is unchanged.
      lag("value").over
        .partitionBy($"key")
        .orderBy($"value")
        .toColumn),
    Row("1") :: Row("2") :: Row(null) :: Row(null) :: Nil)
}

test("lead in window with default value") {
  val rows = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
    (2, "2"), (1, "1"), (2, "2"))
  val data = rows.toDF("key", "value")
  // Offset 2 with explicit default: rows within two of the partition end
  // produce "n/a" instead of null.
  val leadCol = lead("value", 2, "n/a").over
    .partitionBy("key")
    .orderBy("value")
    .toColumn
  checkAnswer(
    data.select(leadCol),
    Row("1") :: Row("1") :: Row("2") :: Row("n/a")
      :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Nil)
}

test("lag in window with default value") {
  val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
    (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
  checkAnswer(
    df.select(
      // Fixed: this test previously called lead(), so lag() was never
      // exercised. lag("value", 2, "n/a") yields "n/a" for the first two rows
      // of each partition, then the partition's value — the same multiset as
      // the lead variant for this data.
      lag("value", 2, "n/a").over
        .partitionBy($"key")
        .orderBy($"value")
        .toColumn),
    Row("1") :: Row("1") :: Row("2") :: Row("n/a")
      :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Nil)
}

test("aggregation in a Row window") {
val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
checkAnswer(
df.select(
avg("key").over
Expand All @@ -33,11 +92,11 @@ class HiveDataFrameWindowSuite extends QueryTest {
.rows
.preceding(1)
.toColumn),
Row(1.0) :: Row(2.0) :: Nil)
Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil)
}

test("aggregation in a Range window") {
val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
checkAnswer(
df.select(
avg("key").over
Expand All @@ -47,11 +106,11 @@ class HiveDataFrameWindowSuite extends QueryTest {
.preceding(1)
.following(1)
.toColumn),
Row(1.0) :: Row(2.0) :: Nil)
Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil)
}

test("multiple aggregate function in window") {
val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
checkAnswer(
df.select(
avg("key").over
Expand All @@ -66,7 +125,7 @@ class HiveDataFrameWindowSuite extends QueryTest {
.preceding(1)
.following(1)
.toColumn),
Row(1, 1.0) :: Row(2, 2.0) :: Nil)
Row(1.0, 1) :: Row(1.0, 2) :: Row(2.0, 2) :: Row(2.0, 4) :: Nil)
}

test("Window function in Unspecified Window") {
Expand Down

0 comments on commit 964c013

Please sign in to comment.