Skip to content

Commit

Permalink
[SPARK-20584][PYSPARK][SQL] Python generic hint support
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Adds `hint` method to PySpark `DataFrame`.

## How was this patch tested?

Unit tests, doctests.

Author: zero323 <zero323@users.noreply.github.com>

Closes #17850 from zero323/SPARK-20584.
  • Loading branch information
zero323 authored and rxin committed May 4, 2017
1 parent 13eb37c commit 02bbe73
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
29 changes: 29 additions & 0 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,35 @@ def withWatermark(self, eventTime, delayThreshold):
jdf = self._jdf.withWatermark(eventTime, delayThreshold)
return DataFrame(jdf, self.sql_ctx)

@since(2.2)
def hint(self, name, *parameters):
    """Specifies some hint on the current DataFrame.

    :param name: A name of the hint.
    :param parameters: Optional parameters. May be given either varargs-style
        (``hint("h", "a", "b")``) or as a single list (``hint("h", ["a", "b"])``).
    :return: :class:`DataFrame`
    :raises TypeError: if ``name`` or any parameter is not a ``str``.

    >>> df.join(df2.hint("broadcast"), "name").show()
    +----+---+------+
    |name|age|height|
    +----+---+------+
    | Bob| 5| 85|
    +----+---+------+
    """
    # Validate the hint name first so a bad name is reported before any
    # parameter handling or parameter errors.
    if not isinstance(name, str):
        raise TypeError("name should be provided as str, got {0}".format(type(name)))

    # Convenience: a single list argument is treated as the parameter list
    # itself, so both calling conventions reach the JVM identically.
    if len(parameters) == 1 and isinstance(parameters[0], list):
        parameters = parameters[0]

    for p in parameters:
        if not isinstance(p, str):
            raise TypeError(
                "all parameters should be str, got {0} of type {1}".format(p, type(p)))

    jdf = self._jdf.hint(name, self._jseq(parameters))
    return DataFrame(jdf, self.sql_ctx)

@since(1.3)
def count(self):
"""Returns the number of rows in this :class:`DataFrame`.
Expand Down
16 changes: 16 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1906,6 +1906,22 @@ def test_functions_broadcast(self):
# planner should not crash without a join
broadcast(df1)._jdf.queryExecution().executedPlan()

def test_generic_hints(self):
    """``DataFrame.hint`` returns a DataFrame for every accepted calling
    convention, and a "broadcast" hint is honored by the physical planner."""
    from pyspark.sql import DataFrame

    # NOTE(review): 10e10 is a float literal (1e11), not an int row count --
    # presumably Spark coerces it to a long; confirm 10**10 was not intended.
    df1 = self.spark.range(10e10).toDF("id")
    df2 = self.spark.range(10e10).toDF("id")

    # Hint with no parameters and with an explicit empty parameter list.
    self.assertIsInstance(df1.hint("broadcast"), DataFrame)
    self.assertIsInstance(df1.hint("broadcast", []), DataFrame)

    # Dummy rules
    # Unrecognized string parameters are accepted in both varargs form and
    # single-list form without raising.
    self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame)
    self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame)

    # The hint must reach the executed physical plan: exactly one
    # BroadcastHashJoin node is expected in its string rendering.
    plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
    self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))

def test_toDF_with_schema_string(self):
data = [Row(key=i, value=str(i)) for i in range(100)]
rdd = self.sc.parallelize(data, 5)
Expand Down

0 comments on commit 02bbe73

Please sign in to comment.