Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-10641][SQL] Add Skewness and Kurtosis Support #9003

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ object FunctionRegistry {
expression[StddevPop]("stddev_pop"),
expression[StddevSamp]("stddev_samp"),
expression[Sum]("sum"),
expression[Variance]("variance"),
expression[VariancePop]("var_pop"),
expression[VarianceSamp]("var_samp"),
expression[Skewness]("skewness"),
expression[Kurtosis]("kurtosis"),

// string functions
expression[Ascii]("ascii"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,11 @@ object HiveTypeCoercion {
case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType))
case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
case Variance(e @ StringType()) => Variance(Cast(e, DoubleType))
case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ package object dsl {
def stddev(e: Expression): Expression = Stddev(e)
def stddev_pop(e: Expression): Expression = StddevPop(e)
def stddev_samp(e: Expression): Expression = StddevSamp(e)
def variance(e: Expression): Expression = Variance(e)
def var_pop(e: Expression): Expression = VariancePop(e)
def var_samp(e: Expression): Expression = VarianceSamp(e)
def skewness(e: Expression): Expression = Skewness(e)
def kurtosis(e: Expression): Expression = Kurtosis(e)

implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -930,3 +930,330 @@ object HyperLogLogPlusPlus {
)
// scalastyle:on
}

/**
* A central moment is the expected value of a specified power of the deviation of a random
* variable from the mean. Central moments are often used to characterize the properties of about
* the shape of a distribution.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also document the behavior for null and NaN values.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not for the case when n = 0 or n =1 but when we have null or NaN in the values. Are we ignoring them or outputting NaN directly?

*
* This class implements online, one-pass algorithms for computing the central moments of a set of
* points.
*
* Returns `Double.NaN` when N = 0 or N = 1
* -third and fourth moments return `Double.NaN` when second moment is zero
*
* References:
* - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments."
* 2015. http://arxiv.org/abs/1510.04923
*
* @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
* Algorithms for calculating variance (Wikipedia)]]
*
* @param child to compute central moments of.
*/
abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable {

/**
* The central moment order to be computed.
*/
protected def momentOrder: Int

override def children: Seq[Expression] = Seq(child)

override def nullable: Boolean = false

override def dataType: DataType = DoubleType

// Expected input data type.
// TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
// new version at planning time (after analysis phase). For now, NullType is added at here
// to make it resolved when we have cases like `select avg(null)`.
// We can use our analyzer to cast NullType to the default data type of the NumericType once
// we remove the old aggregate functions. Then, we will not need NullType at here.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))

override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

/**
* Size of aggregation buffer.
*/
private[this] val bufferSize = 5

override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i =>
AttributeReference(s"M$i", DoubleType)()
}

// Note: although this simply copies aggBufferAttributes, this common code can not be placed
// in the superclass because that will lead to initialization ordering issues.
override val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())

// buffer offsets
private[this] val nOffset = mutableAggBufferOffset
private[this] val meanOffset = mutableAggBufferOffset + 1
private[this] val secondMomentOffset = mutableAggBufferOffset + 2
private[this] val thirdMomentOffset = mutableAggBufferOffset + 3
private[this] val fourthMomentOffset = mutableAggBufferOffset + 4

// frequently used values for online updates
private[this] var delta = 0.0
private[this] var deltaN = 0.0
private[this] var delta2 = 0.0
private[this] var deltaN2 = 0.0
private[this] var n = 0.0
private[this] var mean = 0.0
private[this] var m2 = 0.0
private[this] var m3 = 0.0
private[this] var m4 = 0.0

/**
* Initialize all moments to zero.
*/
override def initialize(buffer: MutableRow): Unit = {
for (aggIndex <- 0 until bufferSize) {
buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0)
}
}

/**
* Update the central moments buffer.
*/
override def update(buffer: MutableRow, input: InternalRow): Unit = {
val v = Cast(child, DoubleType).eval(input)
if (v != null) {
val updateValue = v match {
case d: Double => d
}
n = buffer.getDouble(nOffset)
mean = buffer.getDouble(meanOffset)

n += 1.0
buffer.setDouble(nOffset, n)
delta = updateValue - mean
deltaN = delta / n
mean += deltaN
buffer.setDouble(meanOffset, mean)

if (momentOrder >= 2) {
m2 = buffer.getDouble(secondMomentOffset)
m2 += delta * (delta - deltaN)
buffer.setDouble(secondMomentOffset, m2)
}

if (momentOrder >= 3) {
delta2 = delta * delta
deltaN2 = deltaN * deltaN
m3 = buffer.getDouble(thirdMomentOffset)
m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2)
buffer.setDouble(thirdMomentOffset, m3)
}

if (momentOrder >= 4) {
m4 = buffer.getDouble(fourthMomentOffset)
m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 +
delta * (delta * delta2 - deltaN * deltaN2)
buffer.setDouble(fourthMomentOffset, m4)
}
}
}

/**
* Merge two central moment buffers.
*/
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
val n1 = buffer1.getDouble(nOffset)
val n2 = buffer2.getDouble(inputAggBufferOffset)
val mean1 = buffer1.getDouble(meanOffset)
val mean2 = buffer2.getDouble(inputAggBufferOffset + 1)

var secondMoment1 = 0.0
var secondMoment2 = 0.0

var thirdMoment1 = 0.0
var thirdMoment2 = 0.0

var fourthMoment1 = 0.0
var fourthMoment2 = 0.0

n = n1 + n2
buffer1.setDouble(nOffset, n)
delta = mean2 - mean1
deltaN = delta / n
mean = mean1 + deltaN * n
buffer1.setDouble(mutableAggBufferOffset + 1, mean)

// higher order moments computed according to:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
if (momentOrder >= 2) {
secondMoment1 = buffer1.getDouble(secondMomentOffset)
secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2)
m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2
buffer1.setDouble(secondMomentOffset, m2)
}

if (momentOrder >= 3) {
thirdMoment1 = buffer1.getDouble(thirdMomentOffset)
thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3)
m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 *
(n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1)
buffer1.setDouble(thirdMomentOffset, m3)
}

if (momentOrder >= 4) {
fourthMoment1 = buffer1.getDouble(fourthMomentOffset)
fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4)
m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 *
n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 *
(n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) +
4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1)
buffer1.setDouble(fourthMomentOffset, m4)
}
}

/**
* Compute aggregate statistic from sufficient moments.
* @param centralMoments Length `momentOrder + 1` array of central moments needed to
* compute the aggregate stat.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is useful to mention that whether they are divided by n or not.

*/
def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double

override final def eval(buffer: InternalRow): Any = {
val n = buffer.getDouble(nOffset)
val mean = buffer.getDouble(meanOffset)
val moments = Array.ofDim[Double](momentOrder + 1)
moments(0) = n
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thought moments(0) and moments(1) are just placeholders, we should use the correct values moments(0) = 1.0 and moments(1) = 0.0.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Corrected.

moments(1) = mean
if (momentOrder >= 2) {
moments(2) = buffer.getDouble(secondMomentOffset)
}
if (momentOrder >= 3) {
moments(3) = buffer.getDouble(thirdMomentOffset)
}
if (momentOrder >= 4) {
moments(4) = buffer.getDouble(fourthMomentOffset)
}

getStatistic(n, mean, moments)
}
}

case class Variance(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def prettyName: String = "variance"

override protected val momentOrder = 2

override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")

if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0)
}
}

case class VarianceSamp(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def prettyName: String = "variance_samp"

override protected val momentOrder = 2

override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")

if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0)
}
}

case class VariancePop(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def prettyName: String = "variance_pop"

override protected val momentOrder = 2

override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}")

if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / n
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove || n == 1.0. Population variance should be 0.0 when n == 1.0, which is the same as moments(2) / n.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

}
}

case class Skewness(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def prettyName: String = "skewness"

override protected val momentOrder = 3

override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m3 = moments(3)
if (n == 0.0 || m2 == 0.0) {
Double.NaN
} else {
math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2)
}
}
}

case class Kurtosis(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def prettyName: String = "kurtosis"

override protected val momentOrder = 4

// NOTE: this is the formula for excess kurtosis, which is default for R and SciPy
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
val m2 = moments(2)
val m4 = moments(4)
if (n == 0.0 || m2 == 0.0) {
Double.NaN
} else {
n * m4 / (m2 * m2) - 3.0
}
}
}
Loading