diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e98036a970d44..2f8ab3f43586d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -363,43 +363,68 @@ class Analyzer(
   object ResolvePivot extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) => p
+      case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved)
+        | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p
       case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
         val singleAgg = aggregates.size == 1
-        val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
-          def ifExpr(expr: Expression) = {
-            If(EqualTo(pivotColumn, value), expr, Literal(null))
+        def outputName(value: Literal, aggregate: Expression): String = {
+          if (singleAgg) value.toString else value + "_" + aggregate.sql
+        }
+        if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
+          // Since evaluating |pivotValues| if statements for each input row can get slow this is an
+          // alternate plan that instead uses two steps of aggregation.
+          val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
+          val namedPivotCol = pivotColumn match {
+            case n: NamedExpression => n
+            case _ => Alias(pivotColumn, "__pivot_col")()
+          }
+          val bigGroup = groupByExprs :+ namedPivotCol
+          val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
+          val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow))
+          val pivotAggs = namedAggExps.map { a =>
+            Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues)
+              .toAggregateExpression()
+            , "__pivot_" + a.sql)()
+          }
+          val secondAgg = Aggregate(groupByExprs, groupByExprs ++ pivotAggs, firstAgg)
+          val pivotAggAttribute = pivotAggs.map(_.toAttribute)
+          val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
+            aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) =>
+              Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
+            }
           }
-          aggregates.map { aggregate =>
-            val filteredAggregate = aggregate.transformDown {
-              // Assumption is the aggregate function ignores nulls. This is true for all current
-              // AggregateFunction's with the exception of First and Last in their default mode
-              // (which we handle) and possibly some Hive UDAF's.
-              case First(expr, _) =>
-                First(ifExpr(expr), Literal(true))
-              case Last(expr, _) =>
-                Last(ifExpr(expr), Literal(true))
-              case a: AggregateFunction =>
-                a.withNewChildren(a.children.map(ifExpr))
-            }.transform {
-              // We are duplicating aggregates that are now computing a different value for each
-              // pivot value.
-              // TODO: Don't construct the physical container until after analysis.
-              case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
+          Project(groupByExprs ++ pivotOutputs, secondAgg)
+        } else {
+          val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
+            def ifExpr(expr: Expression) = {
+              If(EqualTo(pivotColumn, value), expr, Literal(null))
             }
-            if (filteredAggregate.fastEquals(aggregate)) {
-              throw new AnalysisException(
-                s"Aggregate expression required for pivot, found '$aggregate'")
+            aggregates.map { aggregate =>
+              val filteredAggregate = aggregate.transformDown {
+                // Assumption is the aggregate function ignores nulls. This is true for all current
+                // AggregateFunction's with the exception of First and Last in their default mode
+                // (which we handle) and possibly some Hive UDAF's.
+                case First(expr, _) =>
+                  First(ifExpr(expr), Literal(true))
+                case Last(expr, _) =>
+                  Last(ifExpr(expr), Literal(true))
+                case a: AggregateFunction =>
+                  a.withNewChildren(a.children.map(ifExpr))
+              }.transform {
+                // We are duplicating aggregates that are now computing a different value for each
+                // pivot value.
+                // TODO: Don't construct the physical container until after analysis.
+                case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
+              }
+              if (filteredAggregate.fastEquals(aggregate)) {
+                throw new AnalysisException(
+                  s"Aggregate expression required for pivot, found '$aggregate'")
+              }
+              Alias(filteredAggregate, outputName(value, aggregate))()
             }
-            val name = if (singleAgg) value.toString else value + "_" + aggregate.sql
-            Alias(filteredAggregate, name)()
           }
+          Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
         }
-        val newGroupByExprs = groupByExprs.map {
-          case UnresolvedAlias(e, _) => e
-          case e => e
-        }
-        Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
     }
   }
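Note on the rewrite above: when every aggregate produces a type `PivotFirst` supports, the rule plans two aggregation steps (first over the group-by keys plus the pivot column, then over the keys alone) instead of one aggregate carrying `|pivotValues|` `If`-wrapped copies of each aggregate. A minimal sketch of the same shape using the public DataFrame API; the helper name and the `year`/`course`/`earnings` columns follow the test fixture and are illustrative, not part of the patch:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, first, sum}

// Sketch only: the rule builds this plan shape directly as a logical plan.
// PivotFirst plays the role first() plays here, but gathers each group's
// values into one array slot per pivot value, extracted by a final Project.
def twoPhasePivotSketch(df: DataFrame): DataFrame = {
  // Step 1: aggregate once per (group key, pivot value) pair, so each input
  // row is touched by exactly one aggregate instead of |pivotValues| Ifs.
  val perPivotValue = df
    .groupBy(col("year"), col("course"))
    .agg(sum(col("earnings")).as("earnings_sum"))
  // Step 2: regroup by the keys alone; each (year, course) group now holds
  // exactly one row, so only rearranging the single value remains.
  perPivotValue
    .groupBy(col("year"))
    .pivot("course", Seq("dotNET", "Java"))
    .agg(first(col("earnings_sum")))
}
```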
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
new file mode 100644
index 0000000000000..9154e96e34e9c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import scala.collection.immutable.HashMap
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.types._
+
+object PivotFirst {
+
+  def supportsDataType(dataType: DataType): Boolean = updateFunction.isDefinedAt(dataType)
+
+  // Currently UnsafeRow does not support the generic update method (throws
+  // UnsupportedOperationException), so we need to explicitly support each DataType.
+  private val updateFunction: PartialFunction[DataType, (MutableRow, Int, Any) => Unit] = {
+    case DoubleType =>
+      (row, offset, value) => row.setDouble(offset, value.asInstanceOf[Double])
+    case IntegerType =>
+      (row, offset, value) => row.setInt(offset, value.asInstanceOf[Int])
+    case LongType =>
+      (row, offset, value) => row.setLong(offset, value.asInstanceOf[Long])
+    case FloatType =>
+      (row, offset, value) => row.setFloat(offset, value.asInstanceOf[Float])
+    case BooleanType =>
+      (row, offset, value) => row.setBoolean(offset, value.asInstanceOf[Boolean])
+    case ShortType =>
+      (row, offset, value) => row.setShort(offset, value.asInstanceOf[Short])
+    case ByteType =>
+      (row, offset, value) => row.setByte(offset, value.asInstanceOf[Byte])
+    case d: DecimalType =>
+      (row, offset, value) => row.setDecimal(offset, value.asInstanceOf[Decimal], d.precision)
+  }
+}
+
+/**
+ * PivotFirst is an aggregate function used in the second phase of a two phase pivot to do the
+ * required rearrangement of values into pivoted form.
+ *
+ * For example on an input of
+ * A | B
+ * --+--
+ * x | 1
+ * y | 2
+ * z | 3
+ *
+ * with pivotColumn=A, valueColumn=B, and pivotColumnValues=[z,y] the output is [3,2].
+ *
+ * @param pivotColumn column that determines which output position to put valueColumn in.
+ * @param valueColumn the column that is being rearranged.
+ * @param pivotColumnValues the list of pivotColumn values in the order of desired output.
+ *                          Values not listed here will be ignored.
+ */
+case class PivotFirst(
+    pivotColumn: Expression,
+    valueColumn: Expression,
+    pivotColumnValues: Seq[Any],
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends ImperativeAggregate {
+
+  override val children: Seq[Expression] = pivotColumn :: valueColumn :: Nil
+
+  override lazy val inputTypes: Seq[AbstractDataType] = children.map(_.dataType)
+
+  override val nullable: Boolean = false
+
+  val valueDataType = valueColumn.dataType
+
+  override val dataType: DataType = ArrayType(valueDataType)
+
+  val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*)
+
+  val indexSize = pivotIndex.size
+
+  private val updateRow: (MutableRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType)
+
+  override def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit = {
+    val pivotColValue = pivotColumn.eval(inputRow)
+    if (pivotColValue != null) {
+      // We ignore rows whose pivot column value is not in the list of pivot column values.
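+      // Slot i of this function's section of the shared buffer, starting at
+      // mutableAggBufferOffset, holds the value for pivotColumnValues(i).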
+      val index = pivotIndex.getOrElse(pivotColValue, -1)
+      if (index >= 0) {
+        val value = valueColumn.eval(inputRow)
+        if (value != null) {
+          updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value)
+        }
+      }
+    }
+  }
+
+  override def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit = {
+    for (i <- 0 until indexSize) {
+      if (!inputAggBuffer.isNullAt(inputAggBufferOffset + i)) {
+        val value = inputAggBuffer.get(inputAggBufferOffset + i, valueDataType)
+        updateRow(mutableAggBuffer, mutableAggBufferOffset + i, value)
+      }
+    }
+  }
+
+  override def initialize(mutableAggBuffer: MutableRow): Unit = valueDataType match {
+    case d: DecimalType =>
+      // Per doc of setDecimal we need to do this instead of setNullAt for DecimalType.
+      for (i <- 0 until indexSize) {
+        mutableAggBuffer.setDecimal(mutableAggBufferOffset + i, null, d.precision)
+      }
+    case _ =>
+      for (i <- 0 until indexSize) {
+        mutableAggBuffer.setNullAt(mutableAggBufferOffset + i)
+      }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val result = new Array[Any](indexSize)
+    for (i <- 0 until indexSize) {
+      result(i) = input.get(mutableAggBufferOffset + i, valueDataType)
+    }
+    new GenericArrayData(result)
+  }
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override lazy val aggBufferAttributes: Seq[AttributeReference] =
+    pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)())
+
+  override lazy val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
+}
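As a concrete illustration of the buffer protocol implemented above, here is a self-contained toy model (plain Scala, no Spark dependencies; every name is hypothetical) that reproduces the scaladoc example: with pivotColumnValues=[z,y], the rows (x,1), (y,2), (z,3) evaluate to [3,2]:

```scala
// Toy model of PivotFirst's agg-buffer protocol: one slot per requested pivot
// value. update scatters a row's value into the slot chosen by its pivot-column
// value (rows with unlisted pivot values are dropped); merge would copy the
// non-null slots of a partial buffer the same way; eval reads slots in order.
object PivotFirstModel extends App {
  val pivotColumnValues = Seq("z", "y")                  // desired output order
  val pivotIndex = pivotColumnValues.zipWithIndex.toMap  // pivot value -> slot
  val buffer = Array.fill[Option[Int]](pivotIndex.size)(None)

  def update(pivotColValue: String, value: Int): Unit =
    pivotIndex.get(pivotColValue).foreach(i => buffer(i) = Some(value))

  update("x", 1) // ignored: "x" is not in pivotColumnValues
  update("y", 2)
  update("z", 3)

  // eval: like GenericArrayData(result), slots come out in pivot-value order.
  println(buffer.map(_.getOrElse("null")).mkString("[", ",", "]")) // [3,2]
}
```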
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 368aa5cd141f0..b17284aa94d2f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -17,14 +17,16 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.expressions.aggregate.PivotFirst
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
 
 class DataFramePivotSuite extends QueryTest with SharedSQLContext{
   import testImplicits._
 
-  test("pivot courses with literals") {
+  test("pivot courses") {
     checkAnswer(
       courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings")),
@@ -32,14 +34,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
     )
   }
 
-  test("pivot year with literals") {
+  test("pivot year") {
     checkAnswer(
       courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
  }
 
-  test("pivot courses with literals and multiple aggregations") {
+  test("pivot courses with multiple aggregations") {
     checkAnswer(
       courseSales.groupBy($"year")
         .pivot("course", Seq("dotNET", "Java"))
@@ -94,4 +96,88 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
       Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
     )
   }
+
+  // Tests for optimized pivot (with PivotFirst) below
+
+  test("optimized pivot planned") {
+    val df = courseSales.groupBy("year")
+      // pivot with extra columns to trigger optimization
+      .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
+      .agg(sum($"earnings"))
+    val queryExecution = sqlContext.executePlan(df.queryExecution.logical)
+    assert(queryExecution.simpleString.contains("pivotfirst"))
+  }
+
+  test("optimized pivot courses with literals") {
+    checkAnswer(
+      courseSales.groupBy("year")
+        // pivot with extra columns to trigger optimization
+        .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
+        .agg(sum($"earnings"))
+        .select("year", "dotNET", "Java"),
+      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("optimized pivot year with literals") {
+    checkAnswer(
+      courseSales.groupBy($"course")
+        // pivot with extra columns to trigger optimization
+        .pivot("year", Seq(2012, 2013) ++ (1 to 10))
+        .agg(sum($"earnings"))
+        .select("course", "2012", "2013"),
+      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("optimized pivot year with string values (cast)") {
+    checkAnswer(
+      courseSales.groupBy("course")
+        // pivot with extra columns to trigger optimization
+        .pivot("year", Seq("2012", "2013") ++ (1 to 10).map(_.toString))
+        .sum("earnings")
+        .select("course", "2012", "2013"),
+      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("optimized pivot DecimalType") {
+    val df = courseSales.select($"course", $"year", $"earnings".cast(DecimalType(10, 2)))
+      .groupBy("year")
+      // pivot with extra columns to trigger optimization
+      .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
+      .agg(sum($"earnings"))
+      .select("year", "dotNET", "Java")
+
+    assertResult(IntegerType)(df.schema("year").dataType)
+    assertResult(DecimalType(20, 2))(df.schema("Java").dataType)
+    assertResult(DecimalType(20, 2))(df.schema("dotNET").dataType)
+
+    checkAnswer(df, Row(2012, BigDecimal(1500000, 2), BigDecimal(2000000, 2)) ::
+      Row(2013, BigDecimal(4800000, 2), BigDecimal(3000000, 2)) :: Nil)
+  }
+
+  test("PivotFirst supported datatypes") {
+    val supportedDataTypes: Seq[DataType] = DoubleType :: IntegerType :: LongType :: FloatType ::
+      BooleanType :: ShortType :: ByteType :: Nil
+    for (datatype <- supportedDataTypes) {
+      assertResult(true)(PivotFirst.supportsDataType(datatype))
+    }
+    assertResult(true)(PivotFirst.supportsDataType(DecimalType(10, 1)))
+    assertResult(false)(PivotFirst.supportsDataType(null))
+    assertResult(false)(PivotFirst.supportsDataType(ArrayType(IntegerType)))
+  }
+
+  test("optimized pivot with multiple aggregations") {
+    checkAnswer(
+      courseSales.groupBy($"year")
+        // pivot with extra columns to trigger optimization
+        .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
+        .agg(sum($"earnings"), avg($"earnings")),
+      Row(Seq(2012, 15000.0, 7500.0, 20000.0, 20000.0) ++ Seq.fill(20)(null): _*) ::
+        Row(Seq(2013, 48000.0, 48000.0, 30000.0, 30000.0) ++ Seq.fill(20)(null): _*) :: Nil
+    )
+  }
 }
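A final note, in the spirit of the "optimized pivot planned" test above: which rewrite a given pivot received can be checked from the plan output. This sketch assumes the suite's `courseSales` fixture (year, course, earnings) and uses the developer-facing plan inspection APIs:

```scala
import org.apache.spark.sql.functions.{col, sum}

// The optimized plan contains a pivotfirst aggregate followed by a Project of
// per-index array extractions; the fallback plan instead shows one conditional
// (if the pivot value matches) copy of each aggregate per pivot value.
val df = courseSales.groupBy("year")
  .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
  .agg(sum(col("earnings")))
df.explain(true) // or check df.queryExecution.simpleString for "pivotfirst"
```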