diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index d9a715fd645f1..6895aa1010956 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql import scala.language.implicitConversions -import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental import org.apache.spark.Logging +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.types._ @@ -890,19 +890,20 @@ class Column(protected[sql] val expr: Expression) extends Logging { def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr) /** - * Define a [[Window]] column. + * Define a windowing column. + * * {{{ * val w = Window.partitionBy("name").orderBy("id") * df.select( - * sum("price").over(w.range.preceding(2)), - * avg("price").over(w.range.preceding(4)), - * avg("price").over(partitionBy("name").orderBy("id).range.preceding(1)) + * sum("price").over(w.rangeBetween(Long.MinValue, 2)), + * avg("price").over(w.rowsBetween(0, 4)) * ) * }}} * * @group expr_ops + * @since 1.4.0 */ - def over(w: Window): Column = w.newColumn(this).toColumn + def over(window: expressions.WindowSpec): Column = window.withAggregate(this) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d78b4c2f8909c..3ec1c4a2f1027 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, Unresol import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.json.JacksonGenerator import org.apache.spark.sql.sources.CreateTableUsingAsSelect @@ -411,7 +411,7 @@ class DataFrame private[sql]( joined.left, joined.right, joinType = Inner, - Some(expressions.EqualTo( + Some(catalyst.expressions.EqualTo( joined.left.resolve(usingColumn), joined.right.resolve(usingColumn)))) ) @@ -480,8 +480,9 @@ class DataFrame private[sql]( // By the time we get here, since we have already run analysis, all attributes should've been // resolved and become AttributeReference. 
val cond = plan.condition.map { _.transform { - case expressions.EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) => - expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) + case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) }} plan.copy(condition = cond) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/Window.scala deleted file mode 100644 index 80272f380b117..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/Window.scala +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.language.implicitConversions - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.expressions._ - - -sealed private[sql] class Frame(private[sql] var boundary: FrameBoundary = null) - -/** - * :: Experimental :: - * An utility to specify the Window Frame Range. - */ -object Frame { - val currentRow: Frame = new Frame(CurrentRow) - val unbounded: Frame = new Frame() - def preceding(n: Int): Frame = if (n == 0) { - new Frame(CurrentRow) - } else { - new Frame(ValuePreceding(n)) - } - - def following(n: Int): Frame = if (n == 0) { - new Frame(CurrentRow) - } else { - new Frame(ValueFollowing(n)) - } -} - -/** - * :: Experimental :: - * A Window object with everything unset. But can build new Window object - * based on it. - */ -@Experimental -object Window extends Window() - -/** - * :: Experimental :: - * A set of methods for window function definition for aggregate expressions. 
- * For example: - * {{{ - * // predefine a window - * val w = Window.partitionBy("name").orderBy("id") - * .rowsBetween(Frame.unbounded, Frame.currentRow) - * df.select( - * avg("age").over(Window.partitionBy("..", "..").orderBy("..", "..") - * .rowsBetween(Frame.unbounded, Frame.currentRow)) - * ) - * - * df.select( - * avg("age").over(Window.partitionBy("..", "..").orderBy("..", "..") - * .rowsBetween(Frame.preceding(50), Frame.following(10))) - * ) - * - * }}} - * - */ -@Experimental -class Window { - private var column: Column = _ - private var partitionSpec: Seq[Expression] = Nil - private var orderSpec: Seq[SortOrder] = Nil - private var frame: WindowFrame = UnspecifiedFrame - - private def this( - column: Column = null, - partitionSpec: Seq[Expression] = Nil, - orderSpec: Seq[SortOrder] = Nil, - frame: WindowFrame = UnspecifiedFrame) { - this() - this.column = column - this.partitionSpec = partitionSpec - this.orderSpec = orderSpec - this.frame = frame - } - - private[sql] def newColumn(c: Column): Window = { - new Window(c, partitionSpec, orderSpec, frame) - } - - /** - * Returns a new [[Window]] partitioned by the specified column. - * {{{ - * // The following 2 are equivalent - * df.over(Window.partitionBy("k1", "k2", ...)) - * df.over(Window.partitionBy($"K1", $"k2", ...)) - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def partitionBy(colName: String, colNames: String*): Window = { - partitionBy((colName +: colNames).map(Column(_)): _*) - } - - /** - * Returns a new [[Window]] partitioned by the specified column. For example: - * {{{ - * df.over(Window.partitionBy($"col1", $"col2")) - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def partitionBy(cols: Column*): Window = { - new Window(column, cols.map(_.expr), orderSpec, frame) - } - - /** - * Returns a new [[Window]] sorted by the specified column within - * the partition. - * {{{ - * // The following 2 are equivalent - * df.over(Window.partitionBy("k1").orderBy("k2", "k3")) - * df.over(Window.partitionBy("k1").orderBy($"k2", $"k3")) - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def orderBy(colName: String, colNames: String*): Window = { - orderBy((colName +: colNames).map(Column(_)): _*) - } - - /** - * Returns a new [[Window]] sorted by the specified column within - * the partition. 
For example - * {{{ - * df.over(Window.partitionBy("k1").orderBy($"k2", $"k3")) - * }}} - * @group window_funcs - */ - @scala.annotation.varargs - def orderBy(cols: Column*): Window = { - val sortOrder: Seq[SortOrder] = cols.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - new Window(column, partitionSpec, sortOrder, frame) - } - - def rowsBetween(start: Frame, end: Frame): Window = { - assert(start.boundary != UnboundedFollowing, "Start can not be UnboundedFollowing") - assert(end.boundary != UnboundedPreceding, "End can not be UnboundedPreceding") - - val s = if (start.boundary == null) UnboundedPreceding else start.boundary - val e = if (end.boundary == null) UnboundedFollowing else end.boundary - - new Window(column, partitionSpec, orderSpec, SpecifiedWindowFrame(RowFrame, s, e)) - } - - def rangeBetween(start: Frame, end: Frame): Window = { - assert(start.boundary != UnboundedFollowing, "Start can not be UnboundedFollowing") - assert(end.boundary != UnboundedPreceding, "End can not be UnboundedPreceding") - - val s = if (start.boundary == null) UnboundedPreceding else start.boundary - val e = if (end.boundary == null) UnboundedFollowing else end.boundary - - new Window(column, partitionSpec, orderSpec, SpecifiedWindowFrame(RangeFrame, s, e)) - } - - /** - * Convert the window definition into a Column object. - * @group window_funcs - */ - private[sql] def toColumn: Column = { - if (column == null) { - throw new AnalysisException("Window didn't bind with expression") - } - val windowExpr = column.expr match { - case Average(child) => WindowExpression( - UnresolvedWindowFunction("avg", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Sum(child) => WindowExpression( - UnresolvedWindowFunction("sum", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(child) => WindowExpression( - UnresolvedWindowFunction("count", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child) => WindowExpression( - // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction("first_value", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child) => WindowExpression( - // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction("last_value", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Min(child) => WindowExpression( - UnresolvedWindowFunction("min", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Max(child) => WindowExpression( - UnresolvedWindowFunction("max", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case wf: WindowFunction => WindowExpression( - wf, - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case x => - throw new UnsupportedOperationException(s"We don't support $x in window operation.") - } - new Column(windowExpr) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala new file mode 100644 index 0000000000000..d4003b2d9cbf6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions._ + +/** + * :: Experimental :: + * Utility functions for defining window in DataFrames. + * + * {{{ + * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * + * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING + * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) + * }}} + * + * @since 1.4.0 + */ +@Experimental +object Window { + + /** + * Creates a [[WindowSpec]] with the partitioning defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowSpec = { + spec.partitionBy(colName, colNames : _*) + } + + /** + * Creates a [[WindowSpec]] with the partitioning defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowSpec = { + spec.partitionBy(cols : _*) + } + + /** + * Creates a [[WindowSpec]] with the ordering defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowSpec = { + spec.orderBy(colName, colNames : _*) + } + + /** + * Creates a [[WindowSpec]] with the ordering defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(cols: Column*): WindowSpec = { + spec.orderBy(cols : _*) + } + + private def spec: WindowSpec = { + new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala new file mode 100644 index 0000000000000..00ecdb47ca5a9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{Column, catalyst} +import org.apache.spark.sql.catalyst.expressions._ + + +/** + * :: Experimental :: + * A window specification that defines the partitioning, ordering, and frame boundaries. + * + * Use the static methods in [[Window]] to create a [[WindowSpec]]. + * + * @since 1.4.0 + */ +@Experimental +class WindowSpec private[sql]( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + frame: catalyst.expressions.WindowFrame) { + + /** + * Defines the partitioning columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowSpec = { + partitionBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Defines the partitioning columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowSpec = { + new WindowSpec(cols.map(_.expr), orderSpec, frame) + } + + /** + * Defines the ordering columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowSpec = { + orderBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Defines the ordering columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(cols: Column*): WindowSpec = { + val sortOrder: Seq[SortOrder] = cols.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + new WindowSpec(partitionSpec, sortOrder, frame) + } + + /** + * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative positions from the current row. For example, "0" means + * "current row", while "-1" means the row before the current row, and "5" means the fifth row + * after the current row. + * + * @param start boundary start, inclusive. + * The frame is unbounded if this is the minimum long value. + * @param end boundary end, inclusive. + * The frame is unbounded if this is the maximum long value. + * @since 1.4.0 + */ + def rowsBetween(start: Long, end: Long): WindowSpec = { + between(RowFrame, start, end) + } + + /** + * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative from the current row. For example, "0" means "current row", + * while "-1" means one off before the current row, and "5" means the five off after the + * current row. + * + * @param start boundary start, inclusive. + * The frame is unbounded if this is the minimum long value. + * @param end boundary end, inclusive. + * The frame is unbounded if this is the maximum long value. 
+ * @since 1.4.0
+ */
+ def rangeBetween(start: Long, end: Long): WindowSpec = {
+ between(RangeFrame, start, end)
+ }
+
+ private def between(typ: FrameType, start: Long, end: Long): WindowSpec = {
+ val boundaryStart = start match {
+ case 0 => CurrentRow
+ case Long.MinValue => UnboundedPreceding
+ case x if x < 0 => ValuePreceding(-start.toInt)
+ case x if x > 0 => ValueFollowing(start.toInt)
+ }
+
+ // The end boundary is derived from `end`, mirroring the `start` handling above.
+ val boundaryEnd = end match {
+ case 0 => CurrentRow
+ case Long.MaxValue => UnboundedFollowing
+ case x if x < 0 => ValuePreceding(-end.toInt)
+ case x if x > 0 => ValueFollowing(end.toInt)
+ }
+
+ new WindowSpec(
+ partitionSpec,
+ orderSpec,
+ SpecifiedWindowFrame(typ, boundaryStart, boundaryEnd))
+ }
+
+ /**
+ * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression.
+ */
+ private[sql] def withAggregate(aggregate: Column): Column = {
+ val windowExpr = aggregate.expr match {
+ case Average(child) => WindowExpression(
+ UnresolvedWindowFunction("avg", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Sum(child) => WindowExpression(
+ UnresolvedWindowFunction("sum", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Count(child) => WindowExpression(
+ UnresolvedWindowFunction("count", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case First(child) => WindowExpression(
+ // TODO this is a hack for Hive UDAF first_value
+ UnresolvedWindowFunction("first_value", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Last(child) => WindowExpression(
+ // TODO this is a hack for Hive UDAF last_value
+ UnresolvedWindowFunction("last_value", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Min(child) => WindowExpression(
+ UnresolvedWindowFunction("min", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Max(child) => WindowExpression(
+ UnresolvedWindowFunction("max", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case wf: WindowFunction => WindowExpression(
+ wf,
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case x =>
+ throw new UnsupportedOperationException(s"$x is not supported in window operation.")
+ }
+ new Column(windowExpr)
+ }
+
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
index 9d0feb4e5bd3d..eeb676d3dc126 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
@@ -14,74 +14,64 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ + package org.apache.spark.sql.hive; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.*; -import org.apache.spark.sql.hive.test.TestHive$; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.util.Utils; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import java.io.File; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.*; +import org.apache.spark.sql.expressions.Window; +import org.apache.spark.sql.hive.test.TestHive$; public class JavaDataFrameSuite { - private transient JavaSparkContext sc; - private transient HiveContext hc; + private transient JavaSparkContext sc; + private transient HiveContext hc; - DataFrame df; + DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { - String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); - if (errorMessage != null) { - Assert.fail(errorMessage); - } + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); } + } - @Before - public void setUp() throws IOException { - hc = TestHive$.MODULE$; - sc = new JavaSparkContext(hc.sparkContext()); + @Before + public void setUp() throws IOException { + hc = TestHive$.MODULE$; + sc = new JavaSparkContext(hc.sparkContext()); - List jsonObjects = new ArrayList(10); - for (int i = 0; i < 10; i++) { - jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); - } - df = hc.jsonRDD(sc.parallelize(jsonObjects)); - df.registerTempTable("window_table"); + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); } + df = hc.jsonRDD(sc.parallelize(jsonObjects)); + df.registerTempTable("window_table"); + } - @After - public void tearDown() throws IOException { - // Clean up tables. - hc.sql("DROP TABLE IF EXISTS window_table"); - } + @After + public void tearDown() throws IOException { + // Clean up tables. 
+ hc.sql("DROP TABLE IF EXISTS window_table"); + } - @Test - public void saveTableAndQueryIt() { - checkAnswer( - df.select( - functions.avg("key").over( - Window$.MODULE$.partitionBy("value") - .orderBy("key") - .rowsBetween(Frame.preceding(1), Frame.following(1)))), - hc.sql("SELECT avg(key) " + - "OVER (PARTITION BY value " + - " ORDER BY key " + - " ROWS BETWEEN 1 preceding and 1 following) " + - "FROM window_table").collectAsList()); - } + @Test + public void saveTableAndQueryIt() { + checkAnswer( + df.select(functions.avg("key").over( + Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))), + hc.sql("SELECT avg(key) " + + "OVER (PARTITION BY value " + + " ORDER BY key " + + " ROWS BETWEEN 1 preceding and 1 following) " + + "FROM window_table").collectAsList()); + } } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 58fe96adab17e..ee21caf636671 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.hive; import java.io.File; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index 29661cb8a5080..6fee3bcb17358 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{Frame, Window, Row, QueryTest} +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -52,9 +53,7 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( - lead("value").over( - Window.partitionBy($"key") - .orderBy($"value"))), + lead("value").over(Window.partitionBy($"key").orderBy($"value"))), sql( """SELECT | lead(value) OVER (PARTITION BY key ORDER BY value) @@ -76,28 +75,13 @@ class HiveDataFrameWindowSuite extends QueryTest { | FROM window_table""".stripMargin).collect()) } - test("last in window with default value") { - val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), - (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") - checkAnswer( - df.select( - last("value").over(Window)), - sql( - """SELECT - | last_value(value) OVER () - | FROM window_table""".stripMargin).collect()) - } - test("lead in window with default value") { val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( - lead("value", 2, "n/a").over( - Window.partitionBy("key") - .orderBy("value"))), + lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), sql( """SELECT | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) @@ -110,9 +94,7 @@ class HiveDataFrameWindowSuite extends QueryTest { df.registerTempTable("window_table") checkAnswer( df.select( - lag("value", 2, "n/a").over( - Window.partitionBy($"key") - 
.orderBy($"value"))), + lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), sql( """SELECT | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) @@ -125,42 +107,18 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( $"key", - max("key").over( - Window.partitionBy("value") - .orderBy("key")), - min("key").over( - Window.partitionBy("value") - .orderBy("key")), - mean("key").over( - Window.partitionBy("value") - .orderBy("key")), - count("key").over( - Window.partitionBy("value") - .orderBy("key")), - sum("key").over( - Window.partitionBy("value") - .orderBy("key")), - ntile("key").over( - Window.partitionBy("value") - .orderBy("key")), - ntile($"key").over( - Window.partitionBy("value") - .orderBy("key")), - rowNumber().over( - Window.partitionBy("value") - .orderBy("key")), - denseRank().over( - Window.partitionBy("value") - .orderBy("key")), - rank().over( - Window.partitionBy("value") - .orderBy("key")), - cumeDist().over( - Window.partitionBy("value") - .orderBy("key")), - percentRank().over( - Window.partitionBy("value") - .orderBy("key"))), + max("key").over(Window.partitionBy("value").orderBy("key")), + min("key").over(Window.partitionBy("value").orderBy("key")), + mean("key").over(Window.partitionBy("value").orderBy("key")), + count("key").over(Window.partitionBy("value").orderBy("key")), + sum("key").over(Window.partitionBy("value").orderBy("key")), + ntile("key").over(Window.partitionBy("value").orderBy("key")), + ntile($"key").over(Window.partitionBy("value").orderBy("key")), + rowNumber().over(Window.partitionBy("value").orderBy("key")), + denseRank().over(Window.partitionBy("value").orderBy("key")), + rank().over(Window.partitionBy("value").orderBy("key")), + cumeDist().over(Window.partitionBy("value").orderBy("key")), + percentRank().over(Window.partitionBy("value").orderBy("key"))), sql( s"""SELECT |key, @@ -184,10 +142,7 @@ class HiveDataFrameWindowSuite extends QueryTest { df.registerTempTable("window_table") checkAnswer( df.select( - avg("key").over( - Window.partitionBy($"value") - .orderBy($"key") - .rowsBetween(Frame.preceding(1), Frame.following(1)))), + avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), sql( """SELECT | avg(key) OVER @@ -200,10 +155,7 @@ class HiveDataFrameWindowSuite extends QueryTest { df.registerTempTable("window_table") checkAnswer( df.select( - avg("key").over( - Window.partitionBy($"value") - .orderBy($"key") - .rangeBetween(Frame.preceding(1), Frame.following(1)))), + avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), sql( """SELECT | avg(key) OVER @@ -217,14 +169,8 @@ class HiveDataFrameWindowSuite extends QueryTest { checkAnswer( df.select( $"key", - first("value").over( - Window.partitionBy($"value") - .orderBy($"key") - .rowsBetween(Frame.preceding(1), Frame.currentRow)), - first("value").over( - Window.partitionBy($"value") - .orderBy($"key") - .rowsBetween(Frame.preceding(2), Frame.preceding(1)))), + first("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 0)), + first("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-2, 1))), sql( """SELECT | key, @@ -242,17 +188,10 @@ class HiveDataFrameWindowSuite extends QueryTest { df.select( $"key", last("value").over( - Window.partitionBy($"value") - .orderBy($"key") - .rowsBetween(Frame.currentRow, Frame.unbounded)), + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), last("value").over( - 
Window.partitionBy($"value") - .orderBy($"key") - .rowsBetween(Frame.unbounded, Frame.currentRow)), - last("value").over( - Window.partitionBy($"value") - .orderBy($"key") - .rowsBetween(Frame.preceding(1), Frame.following(1)))), + Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), + last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), sql( """SELECT | key, @@ -270,18 +209,9 @@ class HiveDataFrameWindowSuite extends QueryTest { df.registerTempTable("window_table") checkAnswer( df.select( - avg("key").over( - Window.partitionBy($"key") - .orderBy($"value") - .rowsBetween(Frame.preceding(1), Frame.currentRow)), - avg("key").over( - Window.partitionBy($"key") - .orderBy($"value") - .rowsBetween(Frame.currentRow, Frame.currentRow)), - avg("key").over( - Window.partitionBy($"key") - .orderBy($"value") - .rowsBetween(Frame.preceding(2), Frame.preceding(1)))), + avg("key").over(Window.partitionBy($"key").orderBy($"value").rowsBetween(-1, 0)), + avg("key").over(Window.partitionBy($"key").orderBy($"value").rowsBetween(0, 0)), + avg("key").over(Window.partitionBy($"key").orderBy($"value").rowsBetween(-2, 1))), sql( """SELECT | avg(key) OVER @@ -300,28 +230,14 @@ class HiveDataFrameWindowSuite extends QueryTest { df.select( $"key", last("value").over( - Window.partitionBy($"value") - .orderBy($"key") - .rangeBetween(Frame.following(1), Frame.unbounded)) + Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue)) .equalTo("2") .as("last_v"), - avg("key") - .over( - Window.partitionBy("value") - .orderBy("key") - .rangeBetween(Frame.preceding(2), Frame.following(1))) + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-2, 1)) .as("avg_key1"), - avg("key") - .over( - Window.partitionBy("value") - .orderBy("key") - .rangeBetween(Frame.currentRow, Frame.following(1))) + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, 1)) .as("avg_key2"), - avg("key") - .over( - Window.partitionBy("value") - .orderBy("key") - .rangeBetween(Frame.preceding(1), Frame.currentRow)) + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) .as("avg_key3") ), sql(