[SPARK-11690][PYSPARK] Add pivot to python api
This PR adds pivot to the Python API of GroupedData with the same syntax as Scala/Java.

Author: Andrew Ray <ray.andrew@gmail.com>

Closes #9653 from aray/sql-pivot-python.

(cherry picked from commit a244779)
Signed-off-by: Yin Huai <yhuai@databricks.com>
aray authored and yhuai committed Nov 13, 2015
1 parent 4a1bcb2 commit 6459a67
Showing 1 changed file with 23 additions and 1 deletion.
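
For context, a minimal usage sketch of the new method (an illustration, not part of the commit), assuming a live SparkContext sc and an initialized SQLContext, and mirroring the df4 doctest fixture added below:

    from pyspark.sql import Row

    # Build the same fixture the doctests below use.
    df4 = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000),
                          Row(course="Java", year=2012, earnings=20000),
                          Row(course="dotNET", year=2012, earnings=5000),
                          Row(course="dotNET", year=2013, earnings=48000),
                          Row(course="Java", year=2013, earnings=30000)]).toDF()

    # Pivot the distinct values of "course" into columns, summing earnings per year.
    df4.groupBy("year").pivot("course").sum("earnings").collect()
    # [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]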
python/pyspark/sql/group.py
@@ -17,7 +17,7 @@

from pyspark import since
from pyspark.rdd import ignore_unicode_prefix
-from pyspark.sql.column import Column, _to_seq
+from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import *

@@ -167,6 +167,23 @@ def sum(self, *cols):
        [Row(sum(age)=7, sum(height)=165)]
        """

    @since(1.6)
    def pivot(self, pivot_col, *values):
        """Pivots a column of the current DataFrame and performs the specified aggregation.

        :param pivot_col: Column to pivot.
        :param values: Optional list of values of pivot_col that will be translated to columns
            in the output DataFrame. If values are not provided, the method will do an immediate
            call to .distinct() on the pivot column.

        >>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect()
        [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]

        >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
        """
        jgd = self._jdf.pivot(_to_java_column(pivot_col),
                              _to_seq(self.sql_ctx._sc, values, _create_column_from_literal))
        return GroupedData(jgd, self.sql_ctx)


def _test():
    import doctest
@@ -182,6 +199,11 @@ def _test():
                          StructField('name', StringType())]))
    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                   Row(name='Bob', age=5, height=85)]).toDF()
    globs['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000),
                                   Row(course="Java", year=2012, earnings=20000),
                                   Row(course="dotNET", year=2012, earnings=5000),
                                   Row(course="dotNET", year=2013, earnings=48000),
                                   Row(course="Java", year=2013, earnings=30000)]).toDF()

    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.group, globs=globs,
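
A short design note on the tradeoff the new docstring describes: omitting values makes pivot call .distinct() on the pivot column eagerly to discover the output columns, so listing the values up front saves that extra pass and pins the column order. A sketch against the df4 fixture above, using the varargs signature as added in this commit:

    # Implicit form: Spark first scans for the distinct values of "course".
    df4.groupBy("year").pivot("course").sum("earnings")

    # Explicit form: the distinct scan is skipped, and the output columns
    # come out in the order given ("dotNET" before "Java").
    df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")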