Skip to content

Commit

Permalink
Add additional analytics functions
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed May 21, 2015
1 parent 57e3bc0 commit 6847825
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,6 @@ class WindowFunctionDefinition protected[sql](
case wf: WindowFunction => WindowExpression(
wf,
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case aggr: AggregateExpression =>
throw new UnsupportedOperationException(
"""Only support Aggregate Functions:
| avg, sum, count, first, last, min, max for now""".stripMargin)
case x =>
throw new UnsupportedOperationException(s"We don't support $x in window operation.")
}
Expand Down
82 changes: 82 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,88 @@ object functions {
new WindowFunctionDefinition().orderBy(cols: _*)
}

/**
 * Window function: NTILE evaluated over the given expression.
 *
 * NTILE divides an ordered partition into a specified number of groups, called buckets,
 * and assigns a bucket number to every row in the partition. This makes tertiles,
 * quartiles, deciles and other common summary statistics easy to calculate.
 *
 * @param e column whose expression is passed as the single argument of `ntile`
 * @return a column wrapping the (still unresolved) `ntile` window function
 * @group window_funcs
 */
def ntile(e: Column): Column =
  UnresolvedWindowFunction("ntile", List(e.expr))

/**
 * Window function: NTILE evaluated over the column with the given name.
 *
 * NTILE divides an ordered partition into a specified number of groups, called buckets,
 * and assigns a bucket number to every row in the partition. This makes tertiles,
 * quartiles, deciles and other common summary statistics easy to calculate.
 *
 * @param columnName name of the column; it is wrapped in a `Column` and forwarded to the
 *                   `Column`-based overload of `ntile`
 * @group window_funcs
 */
def ntile(columnName: String): Column = {
  val namedColumn = Column(columnName)
  ntile(namedColumn)
}

/**
 * Window function: ROW_NUMBER.
 *
 * Gives every row within the partition a unique number, counted sequentially from 1 in
 * the order defined by ORDER BY.
 *
 * @return a column wrapping the (still unresolved) `row_number` window function
 * @group window_funcs
 */
def rowNumber(): Column =
  UnresolvedWindowFunction("row_number", List.empty)

/**
 * Window function: DENSE_RANK.
 *
 * DENSE_RANK differs from RANK in that it leaves no gaps in the ranking sequence when
 * there are ties. For example, if three people tie for second place in a competition,
 * all three are ranked second and the next person is ranked third.
 *
 * @return a column wrapping the (still unresolved) `dense_rank` window function
 * @group window_funcs
 */
def denseRank(): Column =
  UnresolvedWindowFunction("dense_rank", List.empty)

/**
 * Window function: RANK.
 *
 * RANK differs from DENSE_RANK in that it leaves gaps in the ranking sequence when there
 * are ties, whereas DENSE_RANK does not. For example, if three people tie for second
 * place in a competition, all three are ranked second; DENSE_RANK would place the next
 * person third, while RANK places that person fifth.
 *
 * @return a column wrapping the (still unresolved) `rank` window function
 * @group window_funcs
 */
def rank(): Column =
  UnresolvedWindowFunction("rank", List.empty)

/**
 * Window function: CUME_DIST.
 *
 * CUME_DIST (defined as the inverse of percentile in some statistical books) computes the
 * position of a specified value relative to a set of values. For a value x in a set S of
 * size N:
 *
 *   CUME_DIST(x) = (number of values in S coming before and including x in the
 *                   specified order) / N
 *
 * @return a column wrapping the (still unresolved) `cume_dist` window function
 * @group window_funcs
 */
def cumeDist(): Column =
  UnresolvedWindowFunction("cume_dist", List.empty)

/**
 * Window function: PERCENT_RANK.
 *
 * PERCENT_RANK is similar to CUME_DIST, except that its numerator uses rank values
 * rather than row counts:
 *
 *   (rank of row in its partition - 1) / (number of rows in the partition - 1)
 *
 * @return a column wrapping the (still unresolved) `percent_rank` window function
 * @group window_funcs
 */
def percentRank(): Column =
  UnresolvedWindowFunction("percent_rank", List.empty)

//////////////////////////////////////////////////////////////////////////////////////////////
// Non-aggregate functions
//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

package org.apache.spark.sql.hive

import org.apache.spark.sql.{AnalysisException, Row, QueryTest}
import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._

class HiveDataFrameWindowSuite extends QueryTest {
Expand Down Expand Up @@ -59,7 +60,7 @@ class HiveDataFrameWindowSuite extends QueryTest {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
checkAnswer(
df.select(
lead("value").over(
lag("value").over(
partitionBy($"key")
.orderBy($"value"))),
Row("1") :: Row("2") :: Row(null) :: Row(null) :: Nil)
Expand All @@ -82,15 +83,56 @@ class HiveDataFrameWindowSuite extends QueryTest {
(2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
checkAnswer(
df.select(
lead("value", 2, "n/a").over(
lag("value", 2, "n/a").over(
partitionBy($"key")
.orderBy($"value"))),
Row("1") :: Row("1") :: Row("2") :: Row("n/a")
:: Row("n/a") :: Row("n/a") :: Row("n/a") :: Nil)
}

test("rank functions in unspecific window") {
  // The same data is queried through the DataFrame DSL and through SQL on a temp table;
  // the SQL result serves as the expected answer for the DSL query.
  val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
  df.registerTempTable("window_table")

  // A `def`, not a `val`: every window function gets its own freshly built window
  // definition, exactly as when the chain is written out inline each time.
  def byValueOrderedByKey = partitionBy("value").orderBy("key")

  val viaDsl = df.select(
    $"key",
    ntile("key").over(byValueOrderedByKey),
    ntile($"key").over(byValueOrderedByKey),
    rowNumber().over(byValueOrderedByKey),
    denseRank().over(byValueOrderedByKey),
    rank().over(byValueOrderedByKey),
    cumeDist().over(byValueOrderedByKey),
    percentRank().over(byValueOrderedByKey))

  val viaSql = sql(
    s"""SELECT
       |key,
       |ntile(key) over (partition by value order by key),
       |ntile(key) over (partition by value order by key),
       |row_number() over (partition by value order by key),
       |dense_rank() over (partition by value order by key),
       |rank() over (partition by value order by key),
       |cume_dist() over (partition by value order by key),
       |percent_rank() over (partition by value order by key)
       |FROM window_table""".stripMargin).collect

  checkAnswer(viaDsl, viaSql)
}

test("aggregation in a row window") {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
avg("key").over(
Expand All @@ -106,6 +148,7 @@ class HiveDataFrameWindowSuite extends QueryTest {

test("aggregation in a Range window") {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
avg("key").over(
Expand All @@ -119,68 +162,9 @@ class HiveDataFrameWindowSuite extends QueryTest {
Row(1.0) :: Row(1.0) :: Row(2.0) :: Row(2.0) :: Nil)
}

test("multiple aggregate function in window") {
  val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")

  // avg over a row frame reaching one row back.
  val avgOverRows = avg("key").over(
    partitionBy($"value").orderBy($"key").rows.preceding(1))
  // sum over a range frame from one preceding to one following.
  val sumOverRange = sum("key").over(
    partitionBy($"value").orderBy($"key").range.between.preceding(1).and.following(1))

  val expected = Seq(Row(1.0, 2), Row(1.0, 2), Row(2.0, 4), Row(2.0, 4))

  checkAnswer(df.select(avgOverRows, sumOverRange), expected)
}

test("Window function in Unspecified Window") {
  val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value")

  // The window only partitions by key: no ordering and no explicit frame are given.
  val firstValuePerKey = first("value").over(partitionBy($"key"))
  val expected = Seq(Row(1, "1"), Row(2, "2"), Row(2, "2"))

  checkAnswer(df.select($"key", firstValuePerKey), expected)
}

test("Window function in Unspecified Window #2") {
  val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value")

  // The window partitions by key and orders by value, but gives no explicit frame.
  val firstValuePerKey = first("value").over(partitionBy($"key").orderBy($"value"))
  val expected = Seq(Row(1, "1"), Row(2, "2"), Row(2, "2"))

  checkAnswer(df.select($"key", firstValuePerKey), expected)
}

test("Aggregate function in Range Window") {
  val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value")

  // Range frame from one preceding to one following, partitioned by value, ordered by key.
  val rangeWindow =
    partitionBy($"value").orderBy($"key").range.between.preceding(1).and.following(1)
  val firstInRange = first("value").over(rangeWindow)
  val expected = Seq(Row(1, "1"), Row(2, "2"), Row(2, "3"))

  checkAnswer(df.select($"key", firstInRange), expected)
}

test("Aggregate function in Row preceding Window") {
val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
$"key",
Expand All @@ -194,6 +178,7 @@ class HiveDataFrameWindowSuite extends QueryTest {

test("Aggregate function in Row following Window") {
val df = Seq((1, "1"), (2, "2"), (2, "3")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
$"key",
Expand All @@ -207,6 +192,7 @@ class HiveDataFrameWindowSuite extends QueryTest {

test("Multiple aggregate functions in row window") {
val df = Seq((1, "1"), (1, "2"), (3, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
avg("key").over(
Expand Down Expand Up @@ -240,6 +226,7 @@ class HiveDataFrameWindowSuite extends QueryTest {

test("Multiple aggregate functions in range window") {
val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
$"key",
Expand Down

0 comments on commit 6847825

Please sign in to comment.