From b8ff6888e76b437287d7d6bf2d4b9c759710a195 Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Wed, 11 Nov 2015 16:23:24 -0800
Subject: [PATCH] [SPARK-8992][SQL] Add pivot to dataframe api

This adds a pivot method to the DataFrame API. Following the lead of
cube and rollup, it adds a Pivot operator that the analyzer translates
into an Aggregate.

Currently the syntax is:

~~courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings"))~~

~~Would we be interested in the following syntax also/alternatively?~~

courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
// or
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings"))

Later we can add it to `SQLParser`, but as Hive doesn't support it we
can't add it there, right?

~~Also what would be the suggested Java friendly method signature for this?~~

Author: Andrew Ray

Closes #7841 from aray/sql-pivot.

---
 .../sql/catalyst/analysis/Analyzer.scala      |  42 +++++++
 .../plans/logical/basicOperators.scala        |  14 +++
 .../org/apache/spark/sql/GroupedData.scala    | 103 ++++++++++++++++--
 .../scala/org/apache/spark/sql/SQLConf.scala  |   7 ++
 .../spark/sql/DataFramePivotSuite.scala       |  87 +++++++++++++++
 .../apache/spark/sql/test/SQLTestData.scala   |  12 ++
 6 files changed, 255 insertions(+), 10 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a9cd9a77038e7..2f4670b55bdba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -72,6 +72,7 @@ class Analyzer(
       ResolveRelations ::
       ResolveReferences ::
       ResolveGroupingAnalytics ::
+      ResolvePivot ::
       ResolveSortReferences ::
       ResolveGenerate ::
       ResolveFunctions ::
@@ -166,6 +167,10 @@ class Analyzer(
       case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
         g.withNewAggs(assignAliases(g.aggregations))
 
+      case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
+        if child.resolved && hasUnresolvedAlias(groupByExprs) =>
+        Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)
+
       case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
         Project(assignAliases(projectList), child)
     }
@@ -248,6 +253,43 @@ class Analyzer(
     }
   }
 
+  object ResolvePivot extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      case p: Pivot if !p.childrenResolved => p
+      case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
+        val singleAgg = aggregates.size == 1
+        val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
+          def ifExpr(expr: Expression) = {
+            If(EqualTo(pivotColumn, value), expr, Literal(null))
+          }
+          aggregates.map { aggregate =>
+            val filteredAggregate = aggregate.transformDown {
+              // The assumption is that the aggregate function ignores nulls. This is true for all
+              // current AggregateFunctions, with the exception of First and Last in their default
+              // mode (which we handle) and possibly some Hive UDAFs.
+              case First(expr, _) =>
+                First(ifExpr(expr), Literal(true))
+              case Last(expr, _) =>
+                Last(ifExpr(expr), Literal(true))
+              case a: AggregateFunction =>
+                a.withNewChildren(a.children.map(ifExpr))
+            }
+            if (filteredAggregate.fastEquals(aggregate)) {
+              throw new AnalysisException(
+                s"Aggregate expression required for pivot, found '$aggregate'")
+            }
+            val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString
+            Alias(filteredAggregate, name)()
+          }
+        }
+        val newGroupByExprs = groupByExprs.map {
+          case UnresolvedAlias(e) => e
+          case e => e
+        }
+        Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
+    }
+  }
+
   /**
    * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 597f03e752707..32b09b59af436 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -386,6 +386,20 @@ case class Rollup(
     this.copy(aggregations = aggs)
 }
 
+case class Pivot(
+    groupByExprs: Seq[NamedExpression],
+    pivotColumn: Expression,
+    pivotValues: Seq[Literal],
+    aggregates: Seq[Expression],
+    child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ (aggregates match {
+    case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
+    case _ => pivotValues.flatMap { value =>
+      aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)())
+    }
+  })
+}
+
 case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 5babf2cc0ca25..63dd7fbcbe9e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
-import org.apache.spark.sql.types.NumericType
+import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
+import org.apache.spark.sql.types.{StringType, NumericType}
 
 
 /**
@@ -50,14 +50,8 @@ class GroupedData protected[sql](
       aggExprs
     }
 
-    val aliasedAgg = aggregates.map {
-      // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
-      // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
-      // make it a NamedExpression.
-      case u: UnresolvedAttribute => UnresolvedAlias(u)
-      case expr: NamedExpression => expr
-      case expr: Expression => Alias(expr, expr.prettyString)()
-    }
+    val aliasedAgg = aggregates.map(alias)
+
     groupType match {
       case GroupedData.GroupByType =>
         DataFrame(
@@ -68,9 +62,22 @@ class GroupedData protected[sql](
       case GroupedData.CubeType =>
         DataFrame(
           df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
+      case GroupedData.PivotType(pivotCol, values) =>
+        val aliasedGrps = groupingExprs.map(alias)
+        DataFrame(
+          df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
     }
   }
 
+  // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+  // will remove the intermediate Alias for the ExtractValue chain, and we need to alias it
+  // again to make it a NamedExpression.
+  private[this] def alias(expr: Expression): NamedExpression = expr match {
+    case u: UnresolvedAttribute => UnresolvedAlias(u)
+    case expr: NamedExpression => expr
+    case expr: Expression => Alias(expr, expr.prettyString)()
+  }
+
   private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
     : DataFrame = {
 
@@ -273,6 +280,77 @@ class GroupedData protected[sql](
   def sum(colNames: String*): DataFrame = {
     aggregateNumericColumns(colNames : _*)(Sum)
   }
+
+  /**
+   * (Scala-specific) Pivots a column of the current [[DataFrame]] and performs the specified
+   * aggregation.
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
+   *   // Or without specifying column values
+   *   df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
+   * }}}
+   * @param pivotColumn Column to pivot
+   * @param values Optional list of values of pivotColumn that will be translated to columns in the
+   *               output data frame. If values are not provided, the method will do an immediate
+   *               call to .distinct() on the pivot column.
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
+    case _: GroupedData.PivotType =>
+      throw new UnsupportedOperationException("repeated pivots are not supported")
+    case GroupedData.GroupByType =>
+      val pivotValues = if (values.nonEmpty) {
+        values.map {
+          case Column(literal: Literal) => literal
+          case other =>
+            throw new UnsupportedOperationException(
+              s"The values of a pivot must be literals, found $other")
+        }
+      } else {
+        // This is to prevent unintended OOM errors when the number of distinct values is large
+        val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
+        // Get the distinct values of the column and sort them so the result is consistent
+        val values = df.select(pivotColumn)
+          .distinct()
+          .sort(pivotColumn)
+          .map(_.get(0))
+          .take(maxValues + 1)
+          .map(Literal(_)).toSeq
+        if (values.length > maxValues) {
+          throw new RuntimeException(
+            s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
+              "this could indicate an error. " +
+              "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
+              "to at least the number of distinct values of the pivot column.")
+        }
+        values
+      }
+      new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
+    case _ =>
+      throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+  }
+
+  /**
+   * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
+   *   // Or without specifying column values
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   * @param pivotColumn Column to pivot
+   * @param values Optional list of values of pivotColumn that will be translated to columns in the
+   *               output data frame. If values are not provided, the method will do an immediate
+   *               call to .distinct() on the pivot column.
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def pivot(pivotColumn: String, values: Any*): GroupedData = {
+    val resolvedPivotColumn = Column(df.resolve(pivotColumn))
+    pivot(resolvedPivotColumn, values.map(functions.lit): _*)
+  }
 }
 
 
@@ -307,4 +385,9 @@ private[sql] object GroupedData {
    * To indicate it's the ROLLUP
    */
   private[sql] object RollupType extends GroupType
+
+  /**
+   * To indicate it's the PIVOT
+   */
+  private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index e02b502b7b4d5..41d28d448ccc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -437,6 +437,13 @@ private[spark] object SQLConf {
     defaultValue = Some(true),
     isPublic = false)
 
+  val DATAFRAME_PIVOT_MAX_VALUES = intConf(
+    "spark.sql.pivotMaxValues",
+    defaultValue = Some(10000),
+    doc = "When doing a pivot without specifying values for the pivot column, this is the " +
+      "maximum number of (distinct) values that will be collected without error."
+  )
+
   val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles",
     defaultValue = Some(true),
     isPublic = false,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
new file mode 100644
index 0000000000000..0c23d142670c1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFramePivotSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  test("pivot courses with literals") {
+    checkAnswer(
+      courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+        .agg(sum($"earnings")),
+      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("pivot year with literals") {
+    checkAnswer(
+      courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
+      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("pivot courses with literals and multiple aggregations") {
+    checkAnswer(
+      courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+        .agg(sum($"earnings"), avg($"earnings")),
+      Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
+        Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("pivot year with string values (cast)") {
+    checkAnswer(
+      courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
+      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("pivot year with int values") {
+    checkAnswer(
+      courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
+      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("pivot courses with no values") {
+    // Note that Java comes before dotNET in sorted order
+    checkAnswer(
+      courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
+      Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
+    )
+  }
+
+  test("pivot year with no values") {
+    checkAnswer(
+      courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
+      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("pivot max values enforced") {
+    sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
+    intercept[RuntimeException](
+      courseSales.groupBy($"year").pivot($"course")
+    )
+    sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
+      SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 520dea7f7dd92..abad0d7eaaedf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -242,6 +242,17 @@ private[sql] trait SQLTestData { self =>
     df
   }
 
+  protected lazy val courseSales: DataFrame = {
+    val df = sqlContext.sparkContext.parallelize(
+      CourseSales("dotNET", 2012, 10000) ::
+        CourseSales("Java", 2012, 20000) ::
+        CourseSales("dotNET", 2012, 5000) ::
+        CourseSales("dotNET", 2013, 48000) ::
+        CourseSales("Java", 2013, 30000) :: Nil).toDF()
+    df.registerTempTable("courseSales")
+    df
+  }
+
   /**
    * Initialize all test data such that all temp tables are properly registered.
    */
@@ -295,4 +306,5 @@ private[sql] object SQLTestData {
   case class Person(id: Int, name: String, age: Int)
   case class Salary(personId: Int, salary: Double)
   case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
+  case class CourseSales(course: String, year: Int, earnings: Double)
 }
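For intuition, the rewrite that `ResolvePivot` performs can be sketched with the public DataFrame API: each pivot value becomes one conditional aggregate over the whole group, using `when` from `org.apache.spark.sql.functions` (the rule itself builds the equivalent `If(EqualTo(pivotColumn, value), expr, Literal(null))` expressions directly on the logical plan, so this is an approximation of the semantics, not the actual code path):

```scala
import org.apache.spark.sql.functions._

// Roughly what
//   courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
// resolves to: one filtered aggregate per pivot value. A `when` without an
// `otherwise` yields null for non-matching rows, and sum ignores nulls, so
// rows from other courses do not contribute to a given output column.
courseSales.groupBy($"year").agg(
  sum(when($"course" === "dotNET", $"earnings")).as("dotNET"),
  sum(when($"course" === "Java", $"earnings")).as("Java"))
// With the courseSales fixture above this yields
//   [2012, 15000.0, 20000.0] and [2013, 48000.0, 30000.0],
// matching the expected rows in DataFramePivotSuite.
```

This null-filtering trick is also why the rule rewrites `First` and `Last` to their ignore-nulls mode: they are the only built-in aggregates that would otherwise observe the injected nulls.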