-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-10641][SQL] Add Skewness and Kurtosis Support #9003
Changes from 17 commits
bc8ab0c
cf52ed7
7ecf50e
579b9f2
230f66c
1c4c4d0
dc223bc
83fb682
d54fb0d
853922a
4a5350e
dba511b
44c1437
7baac9d
345463e
3ef2faa
fd3f4d6
cf8a14b
b86386a
3045e3b
ff363cc
f49ce5c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -930,3 +930,330 @@ object HyperLogLogPlusPlus { | |
) | ||
// scalastyle:on | ||
} | ||
|
||
/** | ||
* A central moment is the expected value of a specified power of the deviation of a random | ||
* variable from the mean. Central moments are often used to characterize the properties of about | ||
* the shape of a distribution. | ||
* | ||
* This class implements online, one-pass algorithms for computing the central moments of a set of | ||
* points. | ||
* | ||
* Returns `Double.NaN` when N = 0 or N = 1 | ||
* -third and fourth moments return `Double.NaN` when second moment is zero | ||
* | ||
* References: | ||
* - Xiangrui Meng. "Simpler Online Updates for Arbitrary-Order Central Moments." | ||
* 2015. http://arxiv.org/abs/1510.04923 | ||
* | ||
* @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance | ||
* Algorithms for calculating variance (Wikipedia)]] | ||
* | ||
* @param child to compute central moments of. | ||
*/ | ||
abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate with Serializable { | ||
|
||
/** | ||
* The central moment order to be computed. | ||
*/ | ||
protected def momentOrder: Int | ||
|
||
override def children: Seq[Expression] = Seq(child) | ||
|
||
override def nullable: Boolean = false | ||
|
||
override def dataType: DataType = DoubleType | ||
|
||
// Expected input data type. | ||
// TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the | ||
// new version at planning time (after analysis phase). For now, NullType is added at here | ||
// to make it resolved when we have cases like `select avg(null)`. | ||
// We can use our analyzer to cast NullType to the default data type of the NumericType once | ||
// we remove the old aggregate functions. Then, we will not need NullType at here. | ||
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) | ||
|
||
override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) | ||
|
||
/** | ||
* Size of aggregation buffer. | ||
*/ | ||
private[this] val bufferSize = 5 | ||
|
||
override val aggBufferAttributes: Seq[AttributeReference] = Seq.tabulate(bufferSize) { i => | ||
AttributeReference(s"M$i", DoubleType)() | ||
} | ||
|
||
// Note: although this simply copies aggBufferAttributes, this common code can not be placed | ||
// in the superclass because that will lead to initialization ordering issues. | ||
override val inputAggBufferAttributes: Seq[AttributeReference] = | ||
aggBufferAttributes.map(_.newInstance()) | ||
|
||
// buffer offsets | ||
private[this] val nOffset = mutableAggBufferOffset | ||
private[this] val meanOffset = mutableAggBufferOffset + 1 | ||
private[this] val secondMomentOffset = mutableAggBufferOffset + 2 | ||
private[this] val thirdMomentOffset = mutableAggBufferOffset + 3 | ||
private[this] val fourthMomentOffset = mutableAggBufferOffset + 4 | ||
|
||
// frequently used values for online updates | ||
private[this] var delta = 0.0 | ||
private[this] var deltaN = 0.0 | ||
private[this] var delta2 = 0.0 | ||
private[this] var deltaN2 = 0.0 | ||
private[this] var n = 0.0 | ||
private[this] var mean = 0.0 | ||
private[this] var m2 = 0.0 | ||
private[this] var m3 = 0.0 | ||
private[this] var m4 = 0.0 | ||
|
||
/** | ||
* Initialize all moments to zero. | ||
*/ | ||
override def initialize(buffer: MutableRow): Unit = { | ||
for (aggIndex <- 0 until bufferSize) { | ||
buffer.setDouble(mutableAggBufferOffset + aggIndex, 0.0) | ||
} | ||
} | ||
|
||
/** | ||
* Update the central moments buffer. | ||
*/ | ||
override def update(buffer: MutableRow, input: InternalRow): Unit = { | ||
val v = Cast(child, DoubleType).eval(input) | ||
if (v != null) { | ||
val updateValue = v match { | ||
case d: Double => d | ||
} | ||
n = buffer.getDouble(nOffset) | ||
mean = buffer.getDouble(meanOffset) | ||
|
||
n += 1.0 | ||
buffer.setDouble(nOffset, n) | ||
delta = updateValue - mean | ||
deltaN = delta / n | ||
mean += deltaN | ||
buffer.setDouble(meanOffset, mean) | ||
|
||
if (momentOrder >= 2) { | ||
m2 = buffer.getDouble(secondMomentOffset) | ||
m2 += delta * (delta - deltaN) | ||
buffer.setDouble(secondMomentOffset, m2) | ||
} | ||
|
||
if (momentOrder >= 3) { | ||
delta2 = delta * delta | ||
deltaN2 = deltaN * deltaN | ||
m3 = buffer.getDouble(thirdMomentOffset) | ||
m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) | ||
buffer.setDouble(thirdMomentOffset, m3) | ||
} | ||
|
||
if (momentOrder >= 4) { | ||
m4 = buffer.getDouble(fourthMomentOffset) | ||
m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + | ||
delta * (delta * delta2 - deltaN * deltaN2) | ||
buffer.setDouble(fourthMomentOffset, m4) | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Merge two central moment buffers. | ||
*/ | ||
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { | ||
val n1 = buffer1.getDouble(nOffset) | ||
val n2 = buffer2.getDouble(inputAggBufferOffset) | ||
val mean1 = buffer1.getDouble(meanOffset) | ||
val mean2 = buffer2.getDouble(inputAggBufferOffset + 1) | ||
|
||
var secondMoment1 = 0.0 | ||
var secondMoment2 = 0.0 | ||
|
||
var thirdMoment1 = 0.0 | ||
var thirdMoment2 = 0.0 | ||
|
||
var fourthMoment1 = 0.0 | ||
var fourthMoment2 = 0.0 | ||
|
||
n = n1 + n2 | ||
buffer1.setDouble(nOffset, n) | ||
delta = mean2 - mean1 | ||
deltaN = delta / n | ||
mean = mean1 + deltaN * n | ||
buffer1.setDouble(mutableAggBufferOffset + 1, mean) | ||
|
||
// higher order moments computed according to: | ||
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics | ||
if (momentOrder >= 2) { | ||
secondMoment1 = buffer1.getDouble(secondMomentOffset) | ||
secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2) | ||
m2 = secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2 | ||
buffer1.setDouble(secondMomentOffset, m2) | ||
} | ||
|
||
if (momentOrder >= 3) { | ||
thirdMoment1 = buffer1.getDouble(thirdMomentOffset) | ||
thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3) | ||
m3 = thirdMoment1 + thirdMoment2 + deltaN * deltaN * delta * n1 * n2 * | ||
(n1 - n2) + 3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1) | ||
buffer1.setDouble(thirdMomentOffset, m3) | ||
} | ||
|
||
if (momentOrder >= 4) { | ||
fourthMoment1 = buffer1.getDouble(fourthMomentOffset) | ||
fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4) | ||
m4 = fourthMoment1 + fourthMoment2 + deltaN * deltaN * deltaN * delta * n1 * | ||
n2 * (n1 * n1 - n1 * n2 + n2 * n2) + deltaN * deltaN * 6.0 * | ||
(n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) + | ||
4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1) | ||
buffer1.setDouble(fourthMomentOffset, m4) | ||
} | ||
} | ||
|
||
/** | ||
* Compute aggregate statistic from sufficient moments. | ||
* @param centralMoments Length `momentOrder + 1` array of central moments needed to | ||
* compute the aggregate stat. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is useful to mention that whether they are divided by |
||
*/ | ||
def getStatistic(n: Double, mean: Double, centralMoments: Array[Double]): Double | ||
|
||
override final def eval(buffer: InternalRow): Any = { | ||
val n = buffer.getDouble(nOffset) | ||
val mean = buffer.getDouble(meanOffset) | ||
val moments = Array.ofDim[Double](momentOrder + 1) | ||
moments(0) = n | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thought There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Corrected. |
||
moments(1) = mean | ||
if (momentOrder >= 2) { | ||
moments(2) = buffer.getDouble(secondMomentOffset) | ||
} | ||
if (momentOrder >= 3) { | ||
moments(3) = buffer.getDouble(thirdMomentOffset) | ||
} | ||
if (momentOrder >= 4) { | ||
moments(4) = buffer.getDouble(fourthMomentOffset) | ||
} | ||
|
||
getStatistic(n, mean, moments) | ||
} | ||
} | ||
|
||
case class Variance(child: Expression, | ||
mutableAggBufferOffset: Int = 0, | ||
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { | ||
|
||
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = | ||
copy(mutableAggBufferOffset = newMutableAggBufferOffset) | ||
|
||
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = | ||
copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
|
||
override def prettyName: String = "variance" | ||
|
||
override protected val momentOrder = 2 | ||
|
||
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { | ||
require(moments.length == momentOrder + 1, | ||
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") | ||
|
||
if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) | ||
} | ||
} | ||
|
||
case class VarianceSamp(child: Expression, | ||
mutableAggBufferOffset: Int = 0, | ||
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { | ||
|
||
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = | ||
copy(mutableAggBufferOffset = newMutableAggBufferOffset) | ||
|
||
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = | ||
copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
|
||
override def prettyName: String = "variance_samp" | ||
|
||
override protected val momentOrder = 2 | ||
|
||
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { | ||
require(moments.length == momentOrder + 1, | ||
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") | ||
|
||
if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / (n - 1.0) | ||
} | ||
} | ||
|
||
case class VariancePop(child: Expression, | ||
mutableAggBufferOffset: Int = 0, | ||
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { | ||
|
||
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = | ||
copy(mutableAggBufferOffset = newMutableAggBufferOffset) | ||
|
||
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = | ||
copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
|
||
override def prettyName: String = "variance_pop" | ||
|
||
override protected val momentOrder = 2 | ||
|
||
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { | ||
require(moments.length == momentOrder + 1, | ||
s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") | ||
|
||
if (n == 0.0 || n == 1.0) Double.NaN else moments(2) / n | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated |
||
} | ||
} | ||
|
||
case class Skewness(child: Expression, | ||
mutableAggBufferOffset: Int = 0, | ||
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { | ||
|
||
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = | ||
copy(mutableAggBufferOffset = newMutableAggBufferOffset) | ||
|
||
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = | ||
copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
|
||
override def prettyName: String = "skewness" | ||
|
||
override protected val momentOrder = 3 | ||
|
||
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { | ||
require(moments.length == momentOrder + 1, | ||
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") | ||
val m2 = moments(2) | ||
val m3 = moments(3) | ||
if (n == 0.0 || m2 == 0.0) { | ||
Double.NaN | ||
} else { | ||
math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) | ||
} | ||
} | ||
} | ||
|
||
case class Kurtosis(child: Expression, | ||
mutableAggBufferOffset: Int = 0, | ||
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { | ||
|
||
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = | ||
copy(mutableAggBufferOffset = newMutableAggBufferOffset) | ||
|
||
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = | ||
copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
|
||
override def prettyName: String = "kurtosis" | ||
|
||
override protected val momentOrder = 4 | ||
|
||
// NOTE: this is the formula for excess kurtosis, which is default for R and SciPy | ||
override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { | ||
require(moments.length == momentOrder + 1, | ||
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") | ||
val m2 = moments(2) | ||
val m4 = moments(4) | ||
if (n == 0.0 || m2 == 0.0) { | ||
Double.NaN | ||
} else { | ||
n * m4 / (m2 * m2) - 3.0 | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also document the behavior for
null
andNaN
values.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not for the case when
n = 0
orn =1
but when we havenull
orNaN
in the values. Are we ignoring them or outputtingNaN
directly?