Skip to content

Commit

Permalink
polish the unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed May 21, 2015
1 parent 6847825 commit f3fd2d0
Showing 1 changed file with 122 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,54 +48,85 @@ class HiveDataFrameWindowSuite extends QueryTest {

// Verifies the DataFrame `lead` window function (partitioned by `key`, ordered by
// `value`) against the equivalent HiveQL query over the same rows, which are
// registered as the temp table `window_table`.
test("lead in window") {
  val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
  df.registerTempTable("window_table")

  checkAnswer(
    df.select(
      lead("value").over(
        partitionBy($"key")
          .orderBy($"value"))),
    // Expected rows are produced by the equivalent SQL query rather than a hard-coded list.
    sql(
      """SELECT
        | lead(value) OVER (PARTITION BY key ORDER BY value)
        | FROM window_table""".stripMargin).collect())
}

// Verifies the DataFrame `lag` window function (partitioned by `key`, ordered by
// `value`) against the equivalent HiveQL query over the same rows, which are
// registered as the temp table `window_table`.
test("lag in window") {
  val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
  df.registerTempTable("window_table")

  checkAnswer(
    df.select(
      lag("value").over(
        partitionBy($"key")
          .orderBy($"value"))),
    // Expected rows are produced by the equivalent SQL query rather than a hard-coded list.
    sql(
      """SELECT
        | lag(value) OVER (PARTITION BY key ORDER BY value)
        | FROM window_table""".stripMargin).collect())
}

// Verifies `lead` with an explicit offset (2) and default value ("n/a") against
// the equivalent HiveQL query. This test passes the partition/order columns by
// name (String) rather than as `$"..."` column expressions.
test("lead in window with default value") {
  val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
    (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
  df.registerTempTable("window_table")
  checkAnswer(
    df.select(
      lead("value", 2, "n/a").over(
        partitionBy("key")
          .orderBy("value"))),
    // Expected rows are produced by the equivalent SQL query rather than a hard-coded list.
    sql(
      """SELECT
        | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value)
        | FROM window_table""".stripMargin).collect())
}

// Verifies `lag` with an explicit offset (2) and default value ("n/a") against
// the equivalent HiveQL query over the same rows registered as `window_table`.
test("lag in window with default value") {
  val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
    (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
  df.registerTempTable("window_table")
  checkAnswer(
    df.select(
      lag("value", 2, "n/a").over(
        partitionBy($"key")
          .orderBy($"value"))),
    // Expected rows are produced by the equivalent SQL query rather than a hard-coded list.
    sql(
      """SELECT
        | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value)
        | FROM window_table""".stripMargin).collect())
}

test("rank functions in unspecific window") {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
$"key",
max("key").over(
partitionBy("value")
.orderBy("key")),
min("key").over(
partitionBy("value")
.orderBy("key")),
mean("key").over(
partitionBy("value")
.orderBy("key")),
count("key").over(
partitionBy("value")
.orderBy("key")),
sum("key").over(
partitionBy("value")
.orderBy("key")),
ntile("key").over(
partitionBy("value")
.orderBy("key")),
Expand All @@ -120,6 +151,11 @@ class HiveDataFrameWindowSuite extends QueryTest {
sql(
s"""SELECT
|key,
|max(key) over (partition by value order by key),
|min(key) over (partition by value order by key),
|avg(key) over (partition by value order by key),
|count(key) over (partition by value order by key),
|sum(key) over (partition by value order by key),
|ntile(key) over (partition by value order by key),
|ntile(key) over (partition by value order by key),
|row_number() over (partition by value order by key),
Expand All @@ -143,7 +179,11 @@ class HiveDataFrameWindowSuite extends QueryTest {
.preceding(1)
.and
.following(1))),
Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil)
sql(
"""SELECT
| avg(key) OVER
| (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 1 following)
| FROM window_table""".stripMargin).collect())
}

test("aggregation in a Range window") {
Expand All @@ -159,11 +199,15 @@ class HiveDataFrameWindowSuite extends QueryTest {
.preceding(1)
.and
.following(1))),
Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil)
sql(
"""SELECT
| avg(key) OVER
| (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following)
| FROM window_table""".stripMargin).collect())
}

test("Aggregate function in Row preceding Window") {
val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value")
val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "3"), (4, "3")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
Expand All @@ -172,22 +216,65 @@ class HiveDataFrameWindowSuite extends QueryTest {
partitionBy($"value")
.orderBy($"key")
.rows
.preceding(1))),
Row(1, "1") :: Row(2, "2") :: Row(2, "3") :: Nil)
.preceding(1)),
first("value").over(
partitionBy($"value")
.orderBy($"key")
.rows
.between
.preceding(2)
.and
.preceding(1))),
sql(
"""SELECT
| key,
| first_value(value) OVER
| (PARTITION BY value ORDER BY key ROWS 1 preceding),
| first_value(value) OVER
| (PARTITION BY value ORDER BY key ROWS between 2 preceding and 1 preceding)
| FROM window_table""".stripMargin).collect())
}

// Verifies `last` over three ROWS frames (current row to unbounded following,
// unbounded preceding to current row, and 1 preceding to 1 following) against
// the equivalent HiveQL `last_value` queries over the same rows registered as
// `window_table`.
test("Aggregate function in Row following Window") {
  val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value")
  df.registerTempTable("window_table")
  checkAnswer(
    df.select(
      $"key",
      last("value").over(
        partitionBy($"value")
          .orderBy($"key")
          .rows
          .between
          .currentRow()
          .and
          .unboundedFollowing()),
      last("value").over(
        partitionBy($"value")
          .orderBy($"key")
          .rows
          .between
          .unboundedPreceding()
          .and
          .currentRow()),
      last("value").over(
        partitionBy($"value")
          .orderBy($"key")
          .rows
          .between
          .preceding(1)
          .and
          .following(1))),
    // Each SQL window clause below mirrors the frame of the DataFrame column in the
    // same position, so the two sides describe the same computation.
    sql(
      """SELECT
        | key,
        | last_value(value) OVER
        | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following),
        | last_value(value) OVER
        | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row),
        | last_value(value) OVER
        | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 1 following)
        | FROM window_table""".stripMargin).collect())
}

test("Multiple aggregate functions in row window") {
Expand Down Expand Up @@ -216,12 +303,15 @@ class HiveDataFrameWindowSuite extends QueryTest {
.preceding(2)
.and
.preceding(1))),
Row(1.0, 1.0, 1.0) ::
Row(1.0, 1.0, 1.0) ::
Row(1.0, 1.0, 1.0) ::
Row(2.0, 2.0, 2.0) ::
Row(2.0, 2.0, 2.0) ::
Row(3.0, 3.0, 3.0) :: Nil)
sql(
"""SELECT
| avg(key) OVER
| (partition by key ORDER BY value rows 1 preceding),
| avg(key) OVER
| (partition by key ORDER BY value rows between current row and current row),
| avg(key) OVER
| (partition by key ORDER BY value rows between 2 preceding and 1 preceding)
| FROM window_table""".stripMargin).collect())
}

test("Multiple aggregate functions in range window") {
Expand Down Expand Up @@ -268,11 +358,17 @@ class HiveDataFrameWindowSuite extends QueryTest {
.currentRow)
.as("avg_key3")
),
Row(1, false, 1.0, 1.0, 1.0) ::
Row(1, false, 1.0, 1.0, 1.0) ::
Row(2, true, 2.0, 2.0, 2.0) ::
Row(2, true, 2.0, 2.0, 2.0) ::
Row(2, true, 2.0, 2.0, 2.0) ::
Row(2, true, 2.0, 2.0, 2.0) :: Nil)
sql(
"""SELECT
| key,
| last_value(value) OVER
| (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2",
| avg(key) OVER
| (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 following),
| avg(key) OVER
| (PARTITION BY value ORDER BY key RANGE BETWEEN current row and 1 following),
| avg(key) OVER
| (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row)
| FROM window_table""".stripMargin).collect())
}
}

0 comments on commit f3fd2d0

Please sign in to comment.