[SPARK-8992][SQL] Add pivot to dataframe api
This adds a pivot method to the DataFrame API.

Following the lead of cube and rollup, this adds a Pivot operator that is translated into an Aggregate by the analyzer.

Currently the syntax is like:
~~courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings"))~~

~~Would we be interested in the following syntax also/alternatively? and~~

    courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
    //or
    courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings"))
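
Under the hood the analyzer rewrites the pivot into a conditional aggregation. Here is a hand-written sketch of a roughly equivalent query (not the exact plan the rule produces), using the columns from the example above:

    courseSales.groupBy($"year").agg(
      sum(when($"course" === "dotNET", $"earnings")).as("dotNET"),
      sum(when($"course" === "Java", $"earnings")).as("Java"))

Here `when` with no `otherwise` yields null for non-matching rows, which the aggregate ignores; this mirrors the `If(EqualTo(pivotColumn, value), expr, Literal(null))` expression that `ResolvePivot` wraps around each aggregate's children.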

Later we can add it to `SQLParser`, but as Hive doesn't support it, we can't add it there, right?

~~Also what would be the suggested Java friendly method signature for this?~~

Author: Andrew Ray <ray.andrew@gmail.com>

Closes #7841 from aray/sql-pivot.
aray authored and yhuai committed Nov 12, 2015
1 parent 1a21be1 commit b8ff688
Showing 6 changed files with 255 additions and 10 deletions.
@@ -72,6 +72,7 @@ class Analyzer(
ResolveRelations ::
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
@@ -166,6 +167,10 @@ class Analyzer(
case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
g.withNewAggs(assignAliases(g.aggregations))

case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
if child.resolved && hasUnresolvedAlias(groupByExprs) =>
Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)

case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
Project(assignAliases(projectList), child)
}
@@ -248,6 +253,43 @@
}
}

object ResolvePivot extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p: Pivot if !p.childrenResolved => p
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
val singleAgg = aggregates.size == 1
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
def ifExpr(expr: Expression) = {
If(EqualTo(pivotColumn, value), expr, Literal(null))
}
aggregates.map { aggregate =>
val filteredAggregate = aggregate.transformDown {
// Assumption is the aggregate function ignores nulls. This is true for all current
// AggregateFunction's with the exception of First and Last in their default mode
// (which we handle) and possibly some Hive UDAF's.
case First(expr, _) =>
First(ifExpr(expr), Literal(true))
case Last(expr, _) =>
Last(ifExpr(expr), Literal(true))
case a: AggregateFunction =>
a.withNewChildren(a.children.map(ifExpr))
}
if (filteredAggregate.fastEquals(aggregate)) {
throw new AnalysisException(
s"Aggregate expression required for pivot, found '$aggregate'")
}
val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString
Alias(filteredAggregate, name)()
}
}
val newGroupByExprs = groupByExprs.map {
case UnresolvedAlias(e) => e
case e => e
}
Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
}
}

/**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
@@ -386,6 +386,20 @@ case class Rollup(
this.copy(aggregations = aggs)
}

case class Pivot(
groupByExprs: Seq[NamedExpression],
pivotColumn: Expression,
pivotValues: Seq[Literal],
aggregates: Seq[Expression],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ (aggregates match {
case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
case _ => pivotValues.flatMap { value =>
aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)())
}
})
}

case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output

103 changes: 93 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.NumericType
import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.{StringType, NumericType}


/**
@@ -50,14 +50,8 @@ class GroupedData protected[sql](
aggExprs
}

val aliasedAgg = aggregates.map {
// Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
// will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
// make it a NamedExpression.
case u: UnresolvedAttribute => UnresolvedAlias(u)
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
}
val aliasedAgg = aggregates.map(alias)

groupType match {
case GroupedData.GroupByType =>
DataFrame(
@@ -68,9 +62,22 @@
case GroupedData.CubeType =>
DataFrame(
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
case GroupedData.PivotType(pivotCol, values) =>
val aliasedGrps = groupingExprs.map(alias)
DataFrame(
df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
}
}

// Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
// will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
// make it a NamedExpression.
private[this] def alias(expr: Expression): NamedExpression = expr match {
case u: UnresolvedAttribute => UnresolvedAlias(u)
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
}

private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
: DataFrame = {

@@ -273,6 +280,77 @@ class GroupedData protected[sql](
def sum(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames : _*)(Sum)
}

/**
* (Scala-specific) Pivots a column of the current [[DataFrame]] and performs the specified
* aggregation.
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
* // Or without specifying column values
* df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
* }}}
* @param pivotColumn Column to pivot
* @param values Optional list of values of pivotColumn that will be translated to columns in the
* output data frame. If values are not provided, the method will do an immediate
* call to .distinct() on the pivot column.
* @since 1.6.0
*/
@scala.annotation.varargs
def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
case _: GroupedData.PivotType =>
throw new UnsupportedOperationException("repeated pivots are not supported")
case GroupedData.GroupByType =>
val pivotValues = if (values.nonEmpty) {
values.map {
case Column(literal: Literal) => literal
case other =>
throw new UnsupportedOperationException(
s"The values of a pivot must be literals, found $other")
}
} else {
// This is to prevent unintended OOM errors when the number of distinct values is large
val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
// Get the distinct values of the column and sort them so it's consistent
val values = df.select(pivotColumn)
.distinct()
.sort(pivotColumn)
.map(_.get(0))
.take(maxValues + 1)
.map(Literal(_)).toSeq
if (values.length > maxValues) {
throw new RuntimeException(
s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
"this could indicate an error. " +
"If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
s"to at least the number of distinct values of the pivot column.")
}
values
}
new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
case _ =>
throw new UnsupportedOperationException("pivot is only supported after a groupBy")
}

/**
* Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
* // Or without specifying column values
* df.groupBy("year").pivot("course").sum("earnings")
* }}}
* @param pivotColumn Column to pivot
* @param values Optional list of values of pivotColumn that will be translated to columns in the
* output data frame. If values are not provided, the method will do an immediate
* call to .distinct() on the pivot column.
* @since 1.6.0
*/
@scala.annotation.varargs
def pivot(pivotColumn: String, values: Any*): GroupedData = {
val resolvedPivotColumn = Column(df.resolve(pivotColumn))
pivot(resolvedPivotColumn, values.map(functions.lit): _*)
}
}


@@ -307,4 +385,9 @@ private[sql] object GroupedData {
* To indicate it's the ROLLUP
*/
private[sql] object RollupType extends GroupType

/**
* To indicate it's the PIVOT
*/
private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
}
7 changes: 7 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -437,6 +437,13 @@ private[spark] object SQLConf {
defaultValue = Some(true),
isPublic = false)

val DATAFRAME_PIVOT_MAX_VALUES = intConf(
"spark.sql.pivotMaxValues",
defaultValue = Some(10000),
doc = "When doing a pivot without specifying values for the pivot column this is the maximum " +
"number of (distinct) values that will be collected without error."
)

val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles",
defaultValue = Some(true),
isPublic = false,
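
Usage note (not part of the diff): the `spark.sql.pivotMaxValues` cap added above can be raised before pivoting a high-cardinality column without an explicit value list. A minimal sketch, assuming the 1.6-era `SQLContext.setConf(key, value)` API and the `courseSales` DataFrame from the examples:

    // Raise the limit on collected distinct pivot values, then pivot without listing them.
    sqlContext.setConf("spark.sql.pivotMaxValues", "50000")
    courseSales.groupBy("year").pivot("course").sum("earnings")
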
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext

class DataFramePivotSuite extends QueryTest with SharedSQLContext {
import testImplicits._

test("pivot courses with literals") {
checkAnswer(
courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
.agg(sum($"earnings")),
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
)
}

test("pivot year with literals") {
checkAnswer(
courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("pivot courses with literals and multiple aggregations") {
checkAnswer(
courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
.agg(sum($"earnings"), avg($"earnings")),
Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
)
}

test("pivot year with string values (cast)") {
checkAnswer(
courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("pivot year with int values") {
checkAnswer(
courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("pivot courses with no values") {
// Note: Java comes before dotNET in sorted order
checkAnswer(
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
)
}

test("pivot year with no values") {
checkAnswer(
courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("pivot max values inforced") {
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
intercept[RuntimeException](
courseSales.groupBy($"year").pivot($"course")
)
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
}
}
@@ -242,6 +242,17 @@ private[sql] trait SQLTestData { self =>
df
}

protected lazy val courseSales: DataFrame = {
val df = sqlContext.sparkContext.parallelize(
CourseSales("dotNET", 2012, 10000) ::
CourseSales("Java", 2012, 20000) ::
CourseSales("dotNET", 2012, 5000) ::
CourseSales("dotNET", 2013, 48000) ::
CourseSales("Java", 2013, 30000) :: Nil).toDF()
df.registerTempTable("courseSales")
df
}

/**
* Initialize all test data such that all temp tables are properly registered.
*/
@@ -295,4 +306,5 @@ private[sql] object SQLTestData {
case class Person(id: Int, name: String, age: Int)
case class Salary(personId: Int, salary: Double)
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
case class CourseSales(course: String, year: Int, earnings: Double)
}
