Skip to content

Commit

Permalink
[SPARK-11420] Updating Stddev support via Imperative Aggregate
Browse files Browse the repository at this point in the history
Switched stddev support from DeclarativeAggregate to ImperativeAggregate.

Author: JihongMa <linlin200605@gmail.com>

Closes apache#9380 from JihongMA/SPARK-11420.
  • Loading branch information
JihongMA authored and dskrvk committed Nov 13, 2015
1 parent 34d745a commit d721276
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 115 deletions.
4 changes: 2 additions & 2 deletions R/pkg/inst/tests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ test_that("group by, agg functions", {
df3 <- agg(gd, age = "stddev")
expect_is(df3, "DataFrame")
df3_local <- collect(df3)
expect_equal(0, df3_local[df3_local$name == "Andy",][1, 2])
expect_true(is.nan(df3_local[df3_local$name == "Andy",][1, 2]))

df4 <- agg(gd, sumAge = sum(df$age))
expect_is(df4, "DataFrame")
Expand Down Expand Up @@ -1038,7 +1038,7 @@ test_that("group by, agg functions", {
df7 <- agg(gd2, value = "stddev")
df7_local <- collect(df7)
expect_true(abs(df7_local[df7_local$name == "ID1",][1, 2] - 6.928203) < 1e-6)
expect_equal(0, df7_local[df7_local$name == "ID2",][1, 2])
expect_true(is.nan(df7_local[df7_local$name == "ID2",][1, 2]))

mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}",
"{\"name\":\"Andy\", \"age\":30}",
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def describe(self, *cols):
+-------+------------------+-----+
| count| 2| 2|
| mean| 3.5| null|
| stddev|2.1213203435596424| null|
| stddev|2.1213203435596424| NaN|
| min| 2|Alice|
| max| 5| Bob|
+-------+------------------+-----+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,10 @@ object HiveTypeCoercion {

case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ case class Kurtosis(child: Expression,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m4 = moments(4)

if (n == 0.0 || m2 == 0.0) {
Double.NaN
} else {
}
else {
n * m4 / (m2 * m2) - 3.0
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ case class Skewness(child: Expression,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m3 = moments(3)

if (n == 0.0 || m2 == 0.0) {
Double.NaN
} else {
}
else {
math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,117 +17,55 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

case class StddevSamp(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {

// Compute the population standard deviation of a column
case class StddevPop(child: Expression) extends StddevAgg(child) {
override def isSample: Boolean = false
override def prettyName: String = "stddev_pop"
}


// Compute the sample standard deviation of a column
case class StddevSamp(child: Expression) extends StddevAgg(child) {
override def isSample: Boolean = true
override def prettyName: String = "stddev_samp"
}


// Compute standard deviation based on online algorithm specified here:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)

def isSample: Boolean
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def children: Seq[Expression] = child :: Nil
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def nullable: Boolean = true

override def dataType: DataType = resultType

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
override def prettyName: String = "stddev_samp"

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
override protected val momentOrder = 2

private lazy val resultType = DoubleType
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")

private lazy val count = AttributeReference("count", resultType)()
private lazy val avg = AttributeReference("avg", resultType)()
private lazy val mk = AttributeReference("mk", resultType)()
if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0))
}
}

override lazy val aggBufferAttributes = count :: avg :: mk :: Nil
case class StddevPop(
child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {

override lazy val initialValues: Seq[Expression] = Seq(
/* count = */ Cast(Literal(0), resultType),
/* avg = */ Cast(Literal(0), resultType),
/* mk = */ Cast(Literal(0), resultType)
)
def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)

override lazy val updateExpressions: Seq[Expression] = {
val value = Cast(child, resultType)
val newCount = count + Cast(Literal(1), resultType)
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

// update average
// avg = avg + (value - avg)/count
val newAvg = avg + (value - avg) / newCount
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

// update running sum of squared differences from the mean
// Mk = Mk + (value - preAvg) * (value - updatedAvg)
val newMk = mk + (value - avg) * (value - newAvg)
override def prettyName: String = "stddev_pop"

Seq(
/* count = */ If(IsNull(child), count, newCount),
/* avg = */ If(IsNull(child), avg, newAvg),
/* mk = */ If(IsNull(child), mk, newMk)
)
}
override protected val momentOrder = 2

override lazy val mergeExpressions: Seq[Expression] = {

// count merge
val newCount = count.left + count.right

// average merge
val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount

// update sum of square differences
val newMk = {
val avgDelta = avg.right - avg.left
val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount
mk.left + mk.right + mkDelta
}

Seq(
/* count = */ If(IsNull(count.left), count.right,
If(IsNull(count.right), count.left, newCount)),
/* avg = */ If(IsNull(avg.left), avg.right,
If(IsNull(avg.right), avg.left, newAvg)),
/* mk = */ If(IsNull(mk.left), mk.right,
If(IsNull(mk.right), mk.left, newMk))
)
}
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")

override lazy val evaluateExpression: Expression = {
// when count == 0, return null
// when count == 1, return 0
// when count >1
// stddev_samp = sqrt (mk/(count -1))
// stddev_pop = sqrt (mk/count)
val varCol =
if (isSample) {
mk / Cast(count - Cast(Literal(1), resultType), resultType)
} else {
mk / count
}

If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
Cast(Sqrt(varCol), resultType)))
if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ object functions extends LegacyFunctions {
def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) }

/**
* Aggregate function: returns the unbiased sample standard deviation of
* Aggregate function: returns the sample standard deviation of
* the expression in a group.
*
* @group agg_funcs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}

test("stddev") {
val testData2ADev = math.sqrt(4 / 5.0)
val testData2ADev = math.sqrt(4.0 / 5.0)
checkAnswer(
testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev))
Expand All @@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
Row(null, null, null))
Row(Double.NaN, Double.NaN, Double.NaN))
}

test("zero sum") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val emptyDescribeResult = Seq(
Row("count", "0", "0"),
Row("mean", null, null),
Row("stddev", null, null),
Row("stddev", "NaN", "NaN"),
Row("min", null, null),
Row("max", null, null))

Expand Down
11 changes: 2 additions & 9 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
testCodeGen(
"SELECT min(key) FROM testData3x",
Row(1) :: Nil)
// STDDEV
testCodeGen(
"SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a",
(1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25))))
testCodeGen(
"SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2",
Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil)
// Some combinations.
testCodeGen(
"""
Expand All @@ -341,8 +334,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(100, 1, 50.5, 300, 100) :: Nil)
// Aggregate with Code generation handling all null values
testCodeGen(
"SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData",
Row(null, null, null, 0) :: Nil)
"SELECT sum('a'), avg('a'), count(null) FROM testData",
Row(null, null, 0) :: Nil)
} finally {
sqlContext.dropTempTable("testData3x")
}
Expand Down

0 comments on commit d721276

Please sign in to comment.