
[SPARK-19634][SQL][ML][FOLLOW-UP] Improve interface of dataframe vectorized summarizer

## What changes were proposed in this pull request?

Make several improvements to the dataframe vectorized summarizer:

1. Make the summarizer return `Vector` type for all metrics (except "count"). Previously it returned `WrappedArray`, which was not very convenient.

2. Make `MetricsAggregate` inherit the `ImplicitCastInputTypes` trait so it can check and implicitly cast input values (illustrated in the sketch after the Summarizer.scala diff below).

3. Add a "weight" parameter to every single-metric method (see the usage sketch right after this list).

4. Update the documentation and improve the example code in it.

5. Simplify test cases.
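
Below is a minimal Scala sketch of the updated API for items 1 and 3. It is not taken from the patch itself: it assumes a spark-shell style environment with an active `SparkSession` named `spark`, and the dataframe, column names (`features`, `weight`) and values are invented for illustration.

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.Row
import spark.implicits._  // assumes an active SparkSession named `spark`

// Toy input: a feature vector column plus a per-row weight column.
val df = Seq(
  (Vectors.dense(1.0, 2.0), 1.0),
  (Vectors.dense(3.0, 4.0), 2.0)
).toDF("features", "weight")

// Single-metric shortcut with the new weight argument; the result field is a
// Vector (previously a WrappedArray), so it can be pattern-matched directly.
val Row(weightedMean: Vector) =
  df.select(Summarizer.mean($"features", $"weight")).first()

// Multi-metric form: one struct column with a field per requested metric.
val Row(Row(minVec: Vector, maxVec: Vector, cnt: Long)) = df
  .select(Summarizer.metrics("min", "max", "count").summary($"features", $"weight"))
  .first()
```

"count" comes back as a `Long` scalar; every other metric comes back as a `Vector` with one entry per feature coefficient.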

## How was this patch tested?

Tests added and simplified.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19156 from WeichenXu123/improve_vec_summarizer.
WeichenXu123 authored and yanboliang committed Dec 21, 2017
1 parent 9c289a5 commit d3ae3e1
Showing 3 changed files with 341 additions and 213 deletions.
128 changes: 85 additions & 43 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
- import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData}
+ import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
@@ -41,7 +41,7 @@ sealed abstract class SummaryBuilder {
/**
* Returns an aggregate object that contains the summary of the column with the requested metrics.
* @param featuresCol a column that contains features Vector object.
- * @param weightCol a column that contains weight value.
+ * @param weightCol a column that contains weight value. Default weight is 1.0.
* @return an aggregate column that contains the statistics. The exact content of this
* structure is determined during the creation of the builder.
*/
@@ -50,6 +50,7 @@ sealed abstract class SummaryBuilder {

@Since("2.3.0")
def summary(featuresCol: Column): Column = summary(featuresCol, lit(1.0))

}

/**
@@ -60,15 +61,18 @@ sealed abstract class SummaryBuilder {
* This class lets users pick the statistics they would like to extract for a given column. Here is
* an example in Scala:
* {{{
- * val dataframe = ... // Some dataframe containing a feature column
- * val allStats = dataframe.select(Summarizer.metrics("min", "max").summary($"features"))
- * val Row(Row(min_, max_)) = allStats.first()
+ * import org.apache.spark.ml.linalg._
+ * import org.apache.spark.sql.Row
+ * val dataframe = ... // Some dataframe containing a feature column and a weight column
+ * val multiStatsDF = dataframe.select(
+ *     Summarizer.metrics("min", "max", "count").summary($"features", $"weight")
+ * val Row(Row(minVec, maxVec, count)) = multiStatsDF.first()
* }}}
*
* If one wants to get a single metric, shortcuts are also available:
* {{{
* val meanDF = dataframe.select(Summarizer.mean($"features"))
- * val Row(mean_) = meanDF.first()
+ * val Row(meanVec) = meanDF.first()
* }}}
*
* Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD
@@ -94,46 +98,87 @@ object Summarizer extends Logging {
* - min: the minimum for each coefficient.
* - normL2: the Euclidian norm for each coefficient.
* - normL1: the L1 norm of each coefficient (sum of the absolute values).
- * @param firstMetric the metric being provided
- * @param metrics additional metrics that can be provided.
+ * @param metrics metrics that can be provided.
* @return a builder.
* @throws IllegalArgumentException if one of the metric names is not understood.
*
* Note: Currently, the performance of this interface is about 2x~3x slower then using the RDD
* interface.
*/
@Since("2.3.0")
- def metrics(firstMetric: String, metrics: String*): SummaryBuilder = {
- val (typedMetrics, computeMetrics) = getRelevantMetrics(Seq(firstMetric) ++ metrics)
+ @scala.annotation.varargs
+ def metrics(metrics: String*): SummaryBuilder = {
+ require(metrics.size >= 1, "Should include at least one metric")
+ val (typedMetrics, computeMetrics) = getRelevantMetrics(metrics)
new SummaryBuilderImpl(typedMetrics, computeMetrics)
}

@Since("2.3.0")
- def mean(col: Column): Column = getSingleMetric(col, "mean")
+ def mean(col: Column, weightCol: Column): Column = {
+ getSingleMetric(col, weightCol, "mean")
+ }

@Since("2.3.0")
+ def mean(col: Column): Column = mean(col, lit(1.0))

@Since("2.3.0")
+ def variance(col: Column, weightCol: Column): Column = {
+ getSingleMetric(col, weightCol, "variance")
+ }

@Since("2.3.0")
+ def variance(col: Column): Column = variance(col, lit(1.0))

@Since("2.3.0")
+ def count(col: Column, weightCol: Column): Column = {
+ getSingleMetric(col, weightCol, "count")
+ }

@Since("2.3.0")
+ def count(col: Column): Column = count(col, lit(1.0))

@Since("2.3.0")
- def variance(col: Column): Column = getSingleMetric(col, "variance")
+ def numNonZeros(col: Column, weightCol: Column): Column = {
+ getSingleMetric(col, weightCol, "numNonZeros")
+ }

@Since("2.3.0")
+ def numNonZeros(col: Column): Column = numNonZeros(col, lit(1.0))

@Since("2.3.0")
+ def max(col: Column, weightCol: Column): Column = {
+ getSingleMetric(col, weightCol, "max")
+ }

@Since("2.3.0")
+ def max(col: Column): Column = max(col, lit(1.0))

@Since("2.3.0")
- def count(col: Column): Column = getSingleMetric(col, "count")
+ def min(col: Column, weightCol: Column): Column = {
+ getSingleMetric(col, weightCol, "min")
+ }

@Since("2.3.0")
- def numNonZeros(col: Column): Column = getSingleMetric(col, "numNonZeros")
+ def min(col: Column): Column = min(col, lit(1.0))

@Since("2.3.0")
- def max(col: Column): Column = getSingleMetric(col, "max")
+ def normL1(col: Column, weightCol: Column): Column = {
+ getSingleMetric(col, weightCol, "normL1")
+ }

@Since("2.3.0")
- def min(col: Column): Column = getSingleMetric(col, "min")
+ def normL1(col: Column): Column = normL1(col, lit(1.0))

@Since("2.3.0")
- def normL1(col: Column): Column = getSingleMetric(col, "normL1")
+ def normL2(col: Column, weightCol: Column): Column = {
+ getSingleMetric(col, weightCol, "normL2")
+ }

@Since("2.3.0")
- def normL2(col: Column): Column = getSingleMetric(col, "normL2")
+ def normL2(col: Column): Column = normL2(col, lit(1.0))

- private def getSingleMetric(col: Column, metric: String): Column = {
- val c1 = metrics(metric).summary(col)
+ private def getSingleMetric(col: Column, weightCol: Column, metric: String): Column = {
+ val c1 = metrics(metric).summary(col, weightCol)
c1.getField(metric).as(s"$metric($col)")
}
}
@@ -187,8 +232,7 @@ private[ml] object SummaryBuilderImpl extends Logging {
StructType(fields)
}

- private val arrayDType = ArrayType(DoubleType, containsNull = false)
- private val arrayLType = ArrayType(LongType, containsNull = false)
+ private val vectorUDT = new VectorUDT

/**
* All the metrics that can be currently computed by Spark for vectors.
@@ -197,14 +241,14 @@ private[ml] object SummaryBuilderImpl extends Logging {
* metrics that need to de computed internally to get the final result.
*/
private val allMetrics: Seq[(String, Metric, DataType, Seq[ComputeMetric])] = Seq(
- ("mean", Mean, arrayDType, Seq(ComputeMean, ComputeWeightSum)),
- ("variance", Variance, arrayDType, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
+ ("mean", Mean, vectorUDT, Seq(ComputeMean, ComputeWeightSum)),
+ ("variance", Variance, vectorUDT, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
("count", Count, LongType, Seq()),
- ("numNonZeros", NumNonZeros, arrayLType, Seq(ComputeNNZ)),
- ("max", Max, arrayDType, Seq(ComputeMax, ComputeNNZ)),
- ("min", Min, arrayDType, Seq(ComputeMin, ComputeNNZ)),
- ("normL2", NormL2, arrayDType, Seq(ComputeM2)),
- ("normL1", NormL1, arrayDType, Seq(ComputeL1))
+ ("numNonZeros", NumNonZeros, vectorUDT, Seq(ComputeNNZ)),
+ ("max", Max, vectorUDT, Seq(ComputeMax, ComputeNNZ)),
+ ("min", Min, vectorUDT, Seq(ComputeMin, ComputeNNZ)),
+ ("normL2", NormL2, vectorUDT, Seq(ComputeM2)),
+ ("normL1", NormL1, vectorUDT, Seq(ComputeL1))
)

/**
@@ -527,27 +571,28 @@ private[ml] object SummaryBuilderImpl extends Logging {
weightExpr: Expression,
mutableAggBufferOffset: Int,
inputAggBufferOffset: Int)
- extends TypedImperativeAggregate[SummarizerBuffer] {
+ extends TypedImperativeAggregate[SummarizerBuffer] with ImplicitCastInputTypes {

- override def eval(state: SummarizerBuffer): InternalRow = {
+ override def eval(state: SummarizerBuffer): Any = {
val metrics = requestedMetrics.map {
- case Mean => UnsafeArrayData.fromPrimitiveArray(state.mean.toArray)
- case Variance => UnsafeArrayData.fromPrimitiveArray(state.variance.toArray)
+ case Mean => vectorUDT.serialize(state.mean)
+ case Variance => vectorUDT.serialize(state.variance)
case Count => state.count
- case NumNonZeros => UnsafeArrayData.fromPrimitiveArray(
- state.numNonzeros.toArray.map(_.toLong))
- case Max => UnsafeArrayData.fromPrimitiveArray(state.max.toArray)
- case Min => UnsafeArrayData.fromPrimitiveArray(state.min.toArray)
- case NormL2 => UnsafeArrayData.fromPrimitiveArray(state.normL2.toArray)
- case NormL1 => UnsafeArrayData.fromPrimitiveArray(state.normL1.toArray)
+ case NumNonZeros => vectorUDT.serialize(state.numNonzeros)
+ case Max => vectorUDT.serialize(state.max)
+ case Min => vectorUDT.serialize(state.min)
+ case NormL2 => vectorUDT.serialize(state.normL2)
+ case NormL1 => vectorUDT.serialize(state.normL1)
}
InternalRow.apply(metrics: _*)
}

+ override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil

override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil

override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
- val features = udt.deserialize(featuresExpr.eval(row))
+ val features = vectorUDT.deserialize(featuresExpr.eval(row))
val weight = weightExpr.eval(row).asInstanceOf[Double]
state.add(features, weight)
state
@@ -591,7 +636,4 @@ private[ml] object SummaryBuilderImpl extends Logging {
override def prettyName: String = "aggregate_metrics"

}

- private[this] val udt = new VectorUDT

}
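
One practical consequence of the `ImplicitCastInputTypes` / `inputTypes` change above (item 2 in the description): the weight argument no longer has to already be a `DoubleType` column, because the analyzer can insert the cast. The sketch below is illustrative only, under the same assumptions as the earlier example (hypothetical data, an active `SparkSession` named `spark`); the value in the final comment is ordinary weighted-mean arithmetic, not output quoted from the patch.

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.Row
import spark.implicits._  // assumes an active SparkSession named `spark`

// Integer weights instead of doubles: with ImplicitCastInputTypes the weight
// expression should be implicitly cast to DoubleType during analysis.
val df = Seq(
  (Vectors.dense(1.0, 2.0), 1),
  (Vectors.dense(3.0, 4.0), 2)
).toDF("features", "weight")

val Row(weightedMean: Vector) =
  df.select(Summarizer.mean($"features", $"weight")).first()
// weightedMean ≈ [(1*1.0 + 2*3.0)/3, (1*2.0 + 2*4.0)/3] = [2.333..., 3.333...]
```
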
64 changes: 64 additions & 0 deletions mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.stat;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertArrayEquals;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Dataset;
import static org.apache.spark.sql.functions.col;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;

public class JavaSummarizerSuite extends SharedSparkSession {

private transient Dataset<Row> dataset;

@Override
public void setUp() throws IOException {
super.setUp();
List<LabeledPoint> points = new ArrayList<LabeledPoint>();
points.add(new LabeledPoint(0.0, Vectors.dense(1.0, 2.0)));
points.add(new LabeledPoint(0.0, Vectors.dense(3.0, 4.0)));

dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
}

@Test
public void testSummarizer() {
dataset.select(col("features"));
Row result = dataset
.select(Summarizer.metrics("mean", "max", "count").summary(col("features")))
.first().getStruct(0);
Vector meanVec = result.getAs("mean");
Vector maxVec = result.getAs("max");
long count = result.getAs("count");

assertEquals(2L, count);
assertArrayEquals(new double[]{2.0, 3.0}, meanVec.toArray(), 0.0);
assertArrayEquals(new double[]{3.0, 4.0}, maxVec.toArray(), 0.0);
}
}