diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5e6947ae4667c..c21cbd95ed67e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -446,6 +446,61 @@ object functions { UnresolvedWindowFunction("lead", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil) } + /** + * Returns a new [[WindowFunctionDefinition]] partitioned by the specified columns. + * For example: + * {{{ + * // The following 2 are equivalent + * partitionBy("k1", "k2").orderBy("k3") + * partitionBy($"k1", $"k2").orderBy($"k3") + * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowFunctionDefinition = { + new WindowFunctionDefinition().partitionBy(colName, colNames: _*) + } + + /** + * Returns a new [[WindowFunctionDefinition]] partitioned by the specified columns. + * For example: + * {{{ + * partitionBy($"col1", $"col2").orderBy("value") + * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowFunctionDefinition = { + new WindowFunctionDefinition().partitionBy(cols: _*) + } + + /** + * Returns a new [[WindowFunctionDefinition]] sorted by the specified columns. + * For example: + * {{{ + * // The following 2 are equivalent + * orderBy("k2", "k3").partitionBy("k1") + * orderBy($"k2", $"k3").partitionBy("k1") + * }}} + * @group window_funcs + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowFunctionDefinition = { + new WindowFunctionDefinition().orderBy(colName, colNames: _*) + } + + /** + * Returns a new [[WindowFunctionDefinition]] sorted by the specified columns.
+ * For example: + * {{{ + * val w = orderBy($"k2", $"k3").partitionBy("k1") + * }}} + * @group window_funcs + */ + def orderBy(cols: Column*): WindowFunctionDefinition = { + new WindowFunctionDefinition().orderBy(cols: _*) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -1519,6 +1574,4 @@ object functions { UnresolvedFunction(udfName, cols.map(_.expr)) } - def over: WindowFunctionDefinition = new WindowFunctionDefinition() - } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index c116c5fa39eca..43fadc525f533 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -23,9 +23,20 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._ class HiveDataFrameWindowSuite extends QueryTest { - test("reuse window") { + test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - val w = over.partitionBy("key").orderBy("value") + val w = partitionBy("key").orderBy("value") + + checkAnswer( + df.select( + lead("key").over(w).toColumn, + lead("value").over(w).toColumn), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("reuse window orderBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = orderBy("value").partitionBy("key") checkAnswer( df.select(