Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPARK-11420 Updating Stddev support via Imperative Aggregate #9380

Closed
wants to merge 14 commits into from
4 changes: 2 additions & 2 deletions R/pkg/inst/tests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ test_that("group by, agg functions", {
df3 <- agg(gd, age = "stddev")
expect_is(df3, "DataFrame")
df3_local <- collect(df3)
expect_equal(0, df3_local[df3_local$name == "Andy",][1, 2])
expect_true(is.nan(df3_local[df3_local$name == "Andy",][1, 2]))

df4 <- agg(gd, sumAge = sum(df$age))
expect_is(df4, "DataFrame")
Expand Down Expand Up @@ -1038,7 +1038,7 @@ test_that("group by, agg functions", {
df7 <- agg(gd2, value = "stddev")
df7_local <- collect(df7)
expect_true(abs(df7_local[df7_local$name == "ID1",][1, 2] - 6.928203) < 1e-6)
expect_equal(0, df7_local[df7_local$name == "ID2",][1, 2])
expect_true(is.nan(df7_local[df7_local$name == "ID2",][1, 2]))

mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}",
"{\"name\":\"Andy\", \"age\":30}",
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def describe(self, *cols):
+-------+------------------+-----+
| count| 2| 2|
| mean| 3.5| null|
| stddev|2.1213203435596424| null|
| stddev|2.1213203435596424| NaN|
| min| 2|Alice|
| max| 5| Bob|
+-------+------------------+-----+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,10 @@ object HiveTypeCoercion {

case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset)
case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ case class Kurtosis(child: Expression,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m4 = moments(4)

if (n == 0.0 || m2 == 0.0) {
Double.NaN
} else {
}
else {
n * m4 / (m2 * m2) - 3.0
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ case class Skewness(child: Expression,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m3 = moments(3)

if (n == 0.0 || m2 == 0.0) {
Double.NaN
} else {
}
else {
math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,117 +17,55 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

case class StddevSamp(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {

// Compute the population standard deviation of a column
case class StddevPop(child: Expression) extends StddevAgg(child) {
override def isSample: Boolean = false
override def prettyName: String = "stddev_pop"
}


// Compute the sample standard deviation of a column
case class StddevSamp(child: Expression) extends StddevAgg(child) {
override def isSample: Boolean = true
override def prettyName: String = "stddev_samp"
}


// Compute standard deviation based on online algorithm specified here:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)

def isSample: Boolean
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def children: Seq[Expression] = child :: Nil
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def nullable: Boolean = true

override def dataType: DataType = resultType

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
override def prettyName: String = "stddev_samp"

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
override protected val momentOrder = 2

private lazy val resultType = DoubleType
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")

private lazy val count = AttributeReference("count", resultType)()
private lazy val avg = AttributeReference("avg", resultType)()
private lazy val mk = AttributeReference("mk", resultType)()
if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0))
}
}

override lazy val aggBufferAttributes = count :: avg :: mk :: Nil
case class StddevPop(
child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends CentralMomentAgg(child) {

override lazy val initialValues: Seq[Expression] = Seq(
/* count = */ Cast(Literal(0), resultType),
/* avg = */ Cast(Literal(0), resultType),
/* mk = */ Cast(Literal(0), resultType)
)
def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)

override lazy val updateExpressions: Seq[Expression] = {
val value = Cast(child, resultType)
val newCount = count + Cast(Literal(1), resultType)
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

// update average
// avg = avg + (value - avg)/count
val newAvg = avg + (value - avg) / newCount
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

    // update sum of square differences from mean
// Mk = Mk + (value - preAvg) * (value - updatedAvg)
val newMk = mk + (value - avg) * (value - newAvg)
override def prettyName: String = "stddev_pop"

Seq(
/* count = */ If(IsNull(child), count, newCount),
/* avg = */ If(IsNull(child), avg, newAvg),
/* mk = */ If(IsNull(child), mk, newMk)
)
}
override protected val momentOrder = 2

override lazy val mergeExpressions: Seq[Expression] = {

// count merge
val newCount = count.left + count.right

// average merge
val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount

// update sum of square differences
val newMk = {
val avgDelta = avg.right - avg.left
val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount
mk.left + mk.right + mkDelta
}

Seq(
/* count = */ If(IsNull(count.left), count.right,
If(IsNull(count.right), count.left, newCount)),
/* avg = */ If(IsNull(avg.left), avg.right,
If(IsNull(avg.right), avg.left, newAvg)),
/* mk = */ If(IsNull(mk.left), mk.right,
If(IsNull(mk.right), mk.left, newMk))
)
}
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")

override lazy val evaluateExpression: Expression = {
// when count == 0, return null
// when count == 1, return 0
// when count >1
// stddev_samp = sqrt (mk/(count -1))
// stddev_pop = sqrt (mk/count)
val varCol =
if (isSample) {
mk / Cast(count - Cast(Literal(1), resultType), resultType)
} else {
mk / count
}

If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
Cast(Sqrt(varCol), resultType)))
if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ object functions extends LegacyFunctions {
def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) }

/**
* Aggregate function: returns the unbiased sample standard deviation of
* Aggregate function: returns the sample standard deviation of
* the expression in a group.
*
* @group agg_funcs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}

test("stddev") {
val testData2ADev = math.sqrt(4 / 5.0)
val testData2ADev = math.sqrt(4.0 / 5.0)
checkAnswer(
testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev))
Expand All @@ -205,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)),
Row(null, null, null))
Row(Double.NaN, Double.NaN, Double.NaN))
}

test("zero sum") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val emptyDescribeResult = Seq(
Row("count", "0", "0"),
Row("mean", null, null),
Row("stddev", null, null),
Row("stddev", "NaN", "NaN"),
Row("min", null, null),
Row("max", null, null))

Expand Down
11 changes: 2 additions & 9 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
testCodeGen(
"SELECT min(key) FROM testData3x",
Row(1) :: Nil)
// STDDEV
testCodeGen(
"SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a",
(1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25))))
testCodeGen(
"SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2",
Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil)
// Some combinations.
testCodeGen(
"""
Expand All @@ -343,8 +336,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Row(100, 1, 50.5, 300, 100) :: Nil)
// Aggregate with Code generation handling all null values
testCodeGen(
"SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData",
Row(null, null, null, 0) :: Nil)
"SELECT sum('a'), avg('a'), count(null) FROM testData",
Row(null, null, 0) :: Nil)
} finally {
sqlContext.dropTempTable("testData3x")
sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
Expand Down