diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index f7f5b956b5461..0132c93ba2c3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -890,7 +890,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr)
 
   /**
-   * Create a new [[WindowFunctionDefinition]] bundled with this column(expression).
+   * Create a new [[WindowFunctionDefinition]] bundled with this column.
    * {{{
    *   df.select(avg($"value").over...)
    * }}}
@@ -899,6 +899,20 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    */
  def over: WindowFunctionDefinition = new WindowFunctionDefinition(this)
 
+  /**
+   * Reuse an existing [[WindowFunctionDefinition]], bundling it with this column.
+   * {{{
+   *   val w = over.partitionBy("name").orderBy("id")
+   *   df.select(
+   *     sum("price").over(w).between.preceding(2),
+   *     avg("price").over(w).between.preceding(4)
+   *   )
+   * }}}
+   *
+   * @group expr_ops
+   */
+  def over(w: WindowFunctionDefinition): WindowFunctionDefinition = w.newColumn(this)
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala
index bfabb5e2b03ec..7bdd0daebf597 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/WindowFunctionDefinition.scala
@@ -49,11 +49,15 @@ import org.apache.spark.sql.catalyst.expressions._
  */
 @Experimental
 class WindowFunctionDefinition protected[sql](
-    column: Column,
+    column: Column = null,
     partitionSpec: Seq[Expression] = Nil,
     orderSpec: Seq[SortOrder] = Nil,
     frame: WindowFrame = UnspecifiedFrame) {
 
+  private[sql] def newColumn(c: Column): WindowFunctionDefinition = {
+    new WindowFunctionDefinition(c, partitionSpec, orderSpec, frame)
+  }
+
   /**
    * Returns a new [[WindowFunctionDefinition]] partitioned by the specified column.
    * {{{
@@ -218,6 +222,9 @@ class WindowFunctionDefinition protected[sql](
    * @group window_funcs
    */
   def toColumn: Column = {
+    if (column == null) {
+      throw new AnalysisException("Window was not bound to an expression")
+    }
     val windowExpr = column.expr match {
       case Average(child) => WindowExpression(
         UnresolvedWindowFunction("avg", child :: Nil),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6640631cf0719..5e6947ae4667c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -320,6 +320,132 @@ object functions {
    */
   def max(columnName: String): Column = max(Column(columnName))
 
+  //////////////////////////////////////////////////////////////////////////////////////////////
+  // Window functions
+  //////////////////////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * Window function: returns the value of the column one row before the current row,
+   * or null if that row is before the beginning of the window.
+   *
+   * @group window_funcs
+   */
+  def lag(columnName: String): Column = {
+    lag(columnName, 1)
+  }
+
+  /**
+   * Window function: returns the value of the expression one row before the current row,
+   * or null if that row is before the beginning of the window.
+   *
+   * @group window_funcs
+   */
+  def lag(e: Column): Column = {
+    lag(e, 1)
+  }
+
+  /**
+   * Window function: returns the value of the expression `count` rows before the current row,
+   * or null if that row is before the beginning of the window.
+   *
+   * @group window_funcs
+   */
+  def lag(e: Column, count: Int): Column = {
+    lag(e, count, null)
+  }
+
+  /**
+   * Window function: returns the value of the column `count` rows before the current row,
+   * or null if that row is before the beginning of the window.
+   *
+   * @group window_funcs
+   */
+  def lag(columnName: String, count: Int): Column = {
+    lag(columnName, count, null)
+  }
+
+  /**
+   * Window function: returns the value of the column `count` rows before the
+   * current row, or `defaultValue` if that row is before the beginning
+   * of the window.
+   *
+   * @group window_funcs
+   */
+  def lag(columnName: String, count: Int, defaultValue: Any): Column = {
+    lag(Column(columnName), count, defaultValue)
+  }
+
+  /**
+   * Window function: returns the value of the expression `count` rows before the
+   * current row, or `defaultValue` if that row is before the beginning
+   * of the window.
+   *
+   * @group window_funcs
+   */
+  def lag(e: Column, count: Int, defaultValue: Any): Column = {
+    UnresolvedWindowFunction("lag", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil)
+  }
+
+  /**
+   * Window function: returns the value of the column one row after the current row,
+   * or null if that row is past the end of the window.
+   *
+   * @group window_funcs
+   */
+  def lead(columnName: String): Column = {
+    lead(columnName, 1)
+  }
+
+  /**
+   * Window function: returns the value of the expression one row after the current row,
+   * or null if that row is past the end of the window.
+   *
+   * @group window_funcs
+   */
+  def lead(e: Column): Column = {
+    lead(e, 1)
+  }
+
+  /**
+   * Window function: returns the value of the column `count` rows after the current row,
+   * or null if that row is past the end of the window.
+   *
+   * @group window_funcs
+   */
+  def lead(columnName: String, count: Int): Column = {
+    lead(columnName, count, null)
+  }
+
+  /**
+   * Window function: returns the value of the expression `count` rows after the current row,
+   * or null if that row is past the end of the window.
+   *
+   * @group window_funcs
+   */
+  def lead(e: Column, count: Int): Column = {
+    lead(e, count, null)
+  }
+
+  /**
+   * Window function: returns the value of the column `count` rows after the current row,
+   * or `defaultValue` if that row is past the end of the window.
+   *
+   * @group window_funcs
+   */
+  def lead(columnName: String, count: Int, defaultValue: Any): Column = {
+    lead(Column(columnName), count, defaultValue)
+  }
+
+  /**
+   * Window function: returns the value of the expression `count` rows after the current row,
+   * or `defaultValue` if that row is past the end of the window.
+   *
+   * @group window_funcs
+   */
+  def lead(e: Column, count: Int, defaultValue: Any): Column = {
+    UnresolvedWindowFunction("lead", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil)
+  }
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Non-aggregate functions
   //////////////////////////////////////////////////////////////////////////////////////////////
@@ -1393,4 +1519,12 @@ object functions {
     UnresolvedFunction(udfName, cols.map(_.expr))
   }
 
+  /**
+   * Creates an unbound [[WindowFunctionDefinition]] that can be configured with
+   * `partitionBy`/`orderBy` and then bound to a column via `Column.over(w)`.
+   *
+   * @group window_funcs
+   */
+  def over: WindowFunctionDefinition = new WindowFunctionDefinition()
+
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala
index 62c9ed95cb5bd..c116c5fa39eca 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala
@@ -23,8 +23,67 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._
 
 class HiveDataFrameWindowSuite extends QueryTest {
 
+  test("reuse window") {
+    val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+    val w = over.partitionBy("key").orderBy("value")
+
+    checkAnswer(
+      df.select(
+        lead("key").over(w).toColumn,
+        lead("value").over(w).toColumn),
+      Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
+  }
+
+  test("lead in window") {
+    val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+    checkAnswer(
+      df.select(
+        lead("value").over
+          .partitionBy($"key")
+          .orderBy($"value")
+          .toColumn),
+      Row("1") :: Row("2") :: Row(null) :: Row(null) :: Nil)
+  }
+
+  test("lag in window") {
+    val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+    checkAnswer(
+      df.select(
+        lag("value").over
+          .partitionBy($"key")
+          .orderBy($"value")
+          .toColumn),
+      Row("1") :: Row("2") :: Row(null) :: Row(null) :: Nil)
+  }
+
+  test("lead in window with default value") {
+    val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
+      (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+    checkAnswer(
+      df.select(
+        lead("value", 2, "n/a").over
+          .partitionBy("key")
+          .orderBy("value")
+          .toColumn),
+      Row("1") :: Row("1") :: Row("2") :: Row("n/a")
+        :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Nil)
+  }
+
+  test("lag in window with default value") {
+    val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
+      (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+    checkAnswer(
+      df.select(
+        lag("value", 2, "n/a").over
+          .partitionBy($"key")
+          .orderBy($"value")
+          .toColumn),
+      Row("1") :: Row("1") :: Row("2") :: Row("n/a")
+        :: Row("n/a") :: Row("n/a") :: Row("n/a") :: Nil)
+  }
+
   test("aggregation in a Row window") {
-    val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
+    val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
     checkAnswer(
       df.select(
         avg("key").over
@@ -33,11 +92,11 @@ class HiveDataFrameWindowSuite extends QueryTest {
           .rows
           .preceding(1)
           .toColumn),
-      Row(1.0) :: Row(2.0) :: Nil)
+      Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil)
   }
 
   test("aggregation in a Range window") {
-    val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
+    val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
     checkAnswer(
       df.select(
         avg("key").over
@@ -47,11 +106,11 @@ class HiveDataFrameWindowSuite extends QueryTest {
           .preceding(1)
           .following(1)
           .toColumn),
-      Row(1.0) :: Row(2.0) :: Nil)
+      Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil)
   }
 
   test("multiple aggregate function in window") {
-    val df = Seq((1, "1"), (2, "2")).toDF("key", "value")
+    val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
     checkAnswer(
       df.select(
         avg("key").over
@@ -66,7 +125,7 @@ class HiveDataFrameWindowSuite extends QueryTest {
           .preceding(1)
           .following(1)
           .toColumn),
-      Row(1, 1.0) :: Row(2, 2.0) :: Nil)
+      Row(1.0, 1) :: Row(1.0, 2) :: Row(2.0, 2) :: Row(2.0, 4) :: Nil)
   }
 
   test("Window function in Unspecified Window") {
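
A minimal end-to-end usage sketch of the API this patch adds (illustrative only: the column names and TestHive setup mirror HiveDataFrameWindowSuite; `over`, `partitionBy`, `orderBy`, `lag`, `lead`, `Column.over(w)` and `toColumn` are the entry points introduced above):

  import org.apache.spark.sql.functions._
  import org.apache.spark.sql.hive.test.TestHive.implicits._

  val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")

  // Build one window specification and reuse it across several expressions.
  val w = over.partitionBy("key").orderBy("value")

  df.select(
    lead("value").over(w).toColumn,           // next row's value within the partition, or null
    lag("value", 1, "n/a").over(w).toColumn,  // previous row's value, or "n/a" at the window start
    avg("key").over(w).toColumn)              // average of key over the same window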