
Commit

Address code review feedback.
rxin committed May 22, 2015
1 parent dc448fe commit 026d587
Showing 3 changed files with 39 additions and 57 deletions.
@@ -122,11 +122,11 @@ class WindowSpec private[sql](
       case x if x > 0 => ValueFollowing(start.toInt)
     }
 
-    val boundaryEnd = start match {
+    val boundaryEnd = end match {
       case 0 => CurrentRow
-      case Long.MinValue => UnboundedFollowing
-      case x if x < 0 => ValuePreceding(-start.toInt)
-      case x if x > 0 => ValueFollowing(start.toInt)
+      case Long.MaxValue => UnboundedFollowing
+      case x if x < 0 => ValuePreceding(-end.toInt)
+      case x if x > 0 => ValueFollowing(end.toInt)
     }
 
     new WindowSpec(
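Note on the fix above: boundaryEnd previously pattern-matched on start, so the end boundary of a frame was derived from the start offset. A minimal standalone sketch of the corrected mapping (the FrameBoundary types below are stand-ins for Spark's internal boundary expressions, not the real classes):

// Stand-in boundary types; Spark's real ones live in catalyst's window expressions.
sealed trait FrameBoundary
case object CurrentRow extends FrameBoundary
case object UnboundedFollowing extends FrameBoundary
case class ValuePreceding(n: Int) extends FrameBoundary
case class ValueFollowing(n: Int) extends FrameBoundary

// Corrected logic: the end boundary is computed from `end`, and an unbounded
// end is Long.MaxValue (the old code matched Long.MinValue, the *preceding*
// sentinel, here).
def boundaryEnd(end: Long): FrameBoundary = end match {
  case 0 => CurrentRow
  case Long.MaxValue => UnboundedFollowing
  case x if x < 0 => ValuePreceding(-x.toInt)
  case x if x > 0 => ValueFollowing(x.toInt)
}

// Before the fix, rowsBetween(-1, 1) derived both boundaries from -1,
// producing a frame that ended one row *before* the current row.
assert(boundaryEnd(1) == ValueFollowing(1))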
sql/core/src/main/scala/org/apache/spark/sql/functions.scala (19 additions, 0 deletions)
@@ -330,6 +330,7 @@ object functions {
    * null when the current row extends before the beginning of the window.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def lag(columnName: String): Column = {
     lag(columnName, 1)
@@ -340,6 +341,7 @@
    * null when the current row extends before the beginning of the window.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def lag(e: Column): Column = {
     lag(e, 1)
@@ -350,6 +352,7 @@
    * null when the current row extends before the beginning of the window.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def lag(e: Column, count: Int): Column = {
     lag(e, count, null)
@@ -360,6 +363,7 @@
    * null when the current row extends before the beginning of the window.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def lag(columnName: String, count: Int): Column = {
     lag(columnName, count, null)
@@ -371,6 +375,7 @@
    * of the window.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def lag(columnName: String, count: Int, defaultValue: Any): Column = {
     lag(Column(columnName), count, defaultValue)
@@ -382,6 +387,7 @@
    * of the window.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def lag(e: Column, count: Int, defaultValue: Any): Column = {
     UnresolvedWindowFunction("lag", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil)
@@ -392,6 +398,7 @@
    * null when the current row extends before the end of the window.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def lead(columnName: String): Column = {
     lead(columnName, 1)
@@ -402,6 +409,7 @@
    * null when the current row extends before the end of the window.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def lead(e: Column): Column = {
     lead(e, 1)
@@ -412,6 +420,7 @@
    * null when the current row extends before the end of the window.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def lead(columnName: String, count: Int): Column = {
     lead(columnName, count, null)
@@ -422,6 +431,7 @@
    * null when the current row extends before the end of the window.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def lead(e: Column, count: Int): Column = {
     lead(e, count, null)
@@ -432,6 +442,7 @@
    * given default value when the current row extends before the end of the window.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def lead(columnName: String, count: Int, defaultValue: Any): Column = {
     lead(Column(columnName), count, defaultValue)
@@ -442,6 +453,7 @@
    * given default value when the current row extends before the end of the window.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def lead(e: Column, count: Int, defaultValue: Any): Column = {
     UnresolvedWindowFunction("lead", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil)
@@ -454,6 +466,7 @@
    * number of groups called buckets and assigns a bucket number to each row in the partition.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def ntile(e: Column): Column = {
     UnresolvedWindowFunction("ntile", e.expr :: Nil)
@@ -466,6 +479,7 @@
    * number of groups called buckets and assigns a bucket number to each row in the partition.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def ntile(columnName: String): Column = {
     ntile(Column(columnName))
@@ -476,6 +490,7 @@
    * row within the partition.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def rowNumber(): Column = {
     UnresolvedWindowFunction("row_number", Nil)
@@ -488,6 +503,7 @@
    * place and that the next person came in third.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def denseRank(): Column = {
     UnresolvedWindowFunction("dense_rank", Nil)
@@ -500,6 +516,7 @@
    * place and that the next person came in third.
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def rank(): Column = {
     UnresolvedWindowFunction("rank", Nil)
@@ -512,6 +529,7 @@
    * CUME_DIST(x) = number of values in S coming before and including x in the specified order / N
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def cumeDist(): Column = {
     UnresolvedWindowFunction("cume_dist", Nil)
@@ -524,6 +542,7 @@
    * (rank of row in its partition - 1) / (number of rows in the partition - 1)
    *
    * @group window_funcs
+   * @since 1.4.0
    */
   def percentRank(): Column = {
     UnresolvedWindowFunction("percent_rank", Nil)
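All 19 additions to functions.scala are @since 1.4.0 tags, so the API surface is unchanged. For reference, a hedged usage sketch of the lag/lead functions documented above (assumes a SQLContext with implicits imported and a DataFrame df shaped like the ones in the test suite below):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Assumed setup, mirroring the tests:
//   val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
val w = Window.partitionBy("value").orderBy("key")

df.select(
  col("key"),
  lag("key", 1).over(w),      // previous key within the partition; null on the first row
  lead("key", 1, 0).over(w))  // next key within the partition; default 0 on the last row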
@@ -47,7 +47,7 @@ class HiveDataFrameWindowSuite extends QueryTest {
       Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
   }
 
-  test("lead in window") {
+  test("lead") {
     val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
     df.registerTempTable("window_table")

@@ -60,7 +60,7 @@ class HiveDataFrameWindowSuite extends QueryTest {
         | FROM window_table""".stripMargin).collect())
   }
 
-  test("lag in window") {
+  test("lag") {
     val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
     df.registerTempTable("window_table")

@@ -75,7 +75,7 @@ class HiveDataFrameWindowSuite extends QueryTest {
         | FROM window_table""".stripMargin).collect())
   }
 
-  test("lead in window with default value") {
+  test("lead with default value") {
     val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
       (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
     df.registerTempTable("window_table")
@@ -88,7 +88,7 @@ class HiveDataFrameWindowSuite extends QueryTest {
         | FROM window_table""".stripMargin).collect())
   }
 
-  test("lag in window with default value") {
+  test("lag with default value") {
     val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
      (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
     df.registerTempTable("window_table")
@@ -134,23 +134,23 @@ class HiveDataFrameWindowSuite extends QueryTest {
           |rank() over (partition by value order by key),
           |cume_dist() over (partition by value order by key),
           |percent_rank() over (partition by value order by key)
-          |FROM window_table""".stripMargin).collect)
+          |FROM window_table""".stripMargin).collect())
   }
 
-  test("aggregation in a row window") {
+  test("aggregation and rows between") {
     val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
     df.registerTempTable("window_table")
     checkAnswer(
       df.select(
-        avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))),
+        avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))),
       sql(
         """SELECT
           | avg(key) OVER
-          | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 1 following)
+          | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following)
          | FROM window_table""".stripMargin).collect())
   }
 
-  test("aggregation in a Range window") {
+  test("aggregation and range betweens") {
     val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
     df.registerTempTable("window_table")
     checkAnswer(
@@ -163,25 +163,7 @@ class HiveDataFrameWindowSuite extends QueryTest {
         | FROM window_table""".stripMargin).collect())
   }
 
-  test("Aggregate function in Row preceding Window") {
-    val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "3"), (4, "3")).toDF("key", "value")
-    df.registerTempTable("window_table")
-    checkAnswer(
-      df.select(
-        $"key",
-        first("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 0)),
-        first("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-2, 1))),
-      sql(
-        """SELECT
-          | key,
-          | first_value(value) OVER
-          | (PARTITION BY value ORDER BY key ROWS 1 preceding),
-          | first_value(value) OVER
-          | (PARTITION BY value ORDER BY key ROWS between 2 preceding and 1 preceding)
-          | FROM window_table""".stripMargin).collect())
-  }
-
-  test("Aggregate function in Row following Window") {
+  test("aggregation and rows betweens with unbounded") {
     val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value")
     df.registerTempTable("window_table")
     checkAnswer(
@@ -191,7 +173,7 @@ class HiveDataFrameWindowSuite extends QueryTest {
           Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)),
         last("value").over(
           Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)),
-        last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))),
+        last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))),
       sql(
         """SELECT
           | key,
@@ -204,26 +186,7 @@ class HiveDataFrameWindowSuite extends QueryTest {
         | FROM window_table""".stripMargin).collect())
   }
 
-  test("Multiple aggregate functions in row window") {
-    val df = Seq((1, "1"), (1, "2"), (3, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
-    df.registerTempTable("window_table")
-    checkAnswer(
-      df.select(
-        avg("key").over(Window.partitionBy($"key").orderBy($"value").rowsBetween(-1, 0)),
-        avg("key").over(Window.partitionBy($"key").orderBy($"value").rowsBetween(0, 0)),
-        avg("key").over(Window.partitionBy($"key").orderBy($"value").rowsBetween(-2, 1))),
-      sql(
-        """SELECT
-          | avg(key) OVER
-          | (partition by key ORDER BY value rows 1 preceding),
-          | avg(key) OVER
-          | (partition by key ORDER BY value rows between current row and current row),
-          | avg(key) OVER
-          | (partition by key ORDER BY value rows between 2 preceding and 1 preceding)
-          | FROM window_table""".stripMargin).collect())
-  }
-
-  test("Multiple aggregate functions in range window") {
+  test("aggregation and range betweens with unbounded") {
     val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
     df.registerTempTable("window_table")
     checkAnswer(
@@ -233,9 +196,9 @@ class HiveDataFrameWindowSuite extends QueryTest {
           Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue))
           .equalTo("2")
           .as("last_v"),
-        avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-2, 1))
+        avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1))
           .as("avg_key1"),
-        avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, 1))
+        avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue))
           .as("avg_key2"),
         avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0))
           .as("avg_key3")
@@ -246,9 +209,9 @@ class HiveDataFrameWindowSuite extends QueryTest {
          | last_value(value) OVER
          | (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2",
          | avg(key) OVER
-          | (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 following),
+          | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following),
          | avg(key) OVER
-          | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and 1 following),
+          | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following),
          | avg(key) OVER
          | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row)
          | FROM window_table""".stripMargin).collect())
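The renamed tests also pin down the frame semantics exercised here: rowsBetween counts physical rows either side of the current row, rangeBetween compares values of the ordering key, and Long.MinValue / Long.MaxValue are the DataFrame-API spellings of UNBOUNDED PRECEDING / UNBOUNDED FOLLOWING (the mapping fixed in WindowSpec above). A small sketch under the same assumptions as the suite (df has integer "key" and string "value" columns):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val spec = Window.partitionBy("value").orderBy("key")

df.select(
  // ROWS frame: one row before through one row after the current row.
  avg("key").over(spec.rowsBetween(-1, 1)),
  // RANGE frame: all rows whose key lies within [current key - 1, current key + 1].
  avg("key").over(spec.rangeBetween(-1, 1)),
  // Running aggregate: unbounded preceding through the current row.
  sum("key").over(spec.rangeBetween(Long.MinValue, 0)))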
