Skip to content

Commit

Permalink
[SPARK-49201][PS][PYTHON][CONNECT] Reimplement hist plot with Spark…
Browse files Browse the repository at this point in the history
… SQL

### What changes were proposed in this pull request?
Reimplement `hist` plot with Spark SQL

### Why are the changes needed?
Reimplement `hist` plot with Spark SQL to support Spark Connect, since the `pyspark.ml.feature.Bucketizer` has not been supported with Spark Connect.

### Does this PR introduce _any_ user-facing change?
yes, follow plotting functions are enabled in Spark Connect:

- `{Frame, Series}.plot.hist`
- `{Frame, Series}.plot(kind="hist", ...)`

### How was this patch tested?
1, enabled parity tests;
2, manually check with:
```
import pyspark.pandas as ps
import pandas as pd
import numpy as np

np.random.seed(123)
df = pd.DataFrame(np.random.randint(1, 7, 6000), columns=['one'])
df['two'] = df['one'] + np.random.randint(1, 7, 6000)
df = ps.from_pandas(df)
df.plot.hist(bins=12, alpha=0.5)
```

before (Spark Classic):
![image](https://github.com/user-attachments/assets/4d99724f-0ca0-4871-8cb8-2c0774729f4f)

after (Spark Connect):
![image](https://github.com/user-attachments/assets/3df617d7-27ea-47d5-9e7f-42be3c26d92c)

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47708 from zhengruifeng/reimpl_hist.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
zhengruifeng authored and HyukjinKwon committed Aug 13, 2024
1 parent 1113029 commit 70b814b
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 69 deletions.
39 changes: 25 additions & 14 deletions python/pyspark/pandas/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
#

import bisect
import importlib
import math

Expand All @@ -24,7 +25,7 @@
from pandas.core.dtypes.inference import is_integer

from pyspark.sql import functions as F, Column
from pyspark.sql.utils import is_remote
from pyspark.sql.types import DoubleType
from pyspark.pandas.missing import unsupported_function
from pyspark.pandas.config import get_option
from pyspark.pandas.utils import name_like_string
Expand Down Expand Up @@ -147,10 +148,9 @@ def get_bins(sdf, bins):

@staticmethod
def compute_hist(psdf, bins):
from pyspark.ml.feature import Bucketizer

# 'data' is a Spark DataFrame that selects one column.
assert isinstance(bins, (np.ndarray, np.generic))
assert len(bins) > 2, "the number of buckets must be higher than 2."

sdf = psdf._internal.spark_frame
scols = []
Expand Down Expand Up @@ -181,14 +181,31 @@ def compute_hist(psdf, bins):
colnames = sdf.columns
bucket_names = ["__{}_bucket".format(colname) for colname in colnames]

# TODO(SPARK-49202): register this function in scala side
@F.udf(returnType=DoubleType())
def binary_search_for_buckets(value):
# Given bins = [1.0, 2.0, 3.0, 4.0]
# the intervals are:
# [1.0, 2.0) -> 0.0
# [2.0, 3.0) -> 1.0
# [3.0, 4.0] -> 2.0 (the last bucket is a closed interval)
if value < bins[0] or value > bins[-1]:
raise ValueError(f"value {value} out of the bins bounds: [{bins[0]}, {bins[-1]}]")

if value == bins[-1]:
idx = len(bins) - 2
else:
idx = bisect.bisect(bins, value) - 1
return float(idx)

output_df = None
for group_id, (colname, bucket_name) in enumerate(zip(colnames, bucket_names)):
# creates a Bucketizer to get corresponding bin of each value
bucketizer = Bucketizer(
splits=bins, inputCol=colname, outputCol=bucket_name, handleInvalid="skip"
)
# sdf.na.drop to match handleInvalid="skip" in Bucketizer

bucket_df = bucketizer.transform(sdf)
bucket_df = sdf.na.drop(subset=[colname]).withColumn(
bucket_name,
binary_search_for_buckets(F.col(colname).cast("double")),
)

if output_df is None:
output_df = bucket_df.select(
Expand Down Expand Up @@ -595,9 +612,6 @@ def _get_plot_backend(backend=None):
return module

def __call__(self, kind="line", backend=None, **kwargs):
if is_remote() and kind == "hist":
return unsupported_function(class_name="pd.DataFrame", method_name=kind)()

plot_backend = PandasOnSparkPlotAccessor._get_plot_backend(backend)
plot_data = self.data

Expand Down Expand Up @@ -974,9 +988,6 @@ def hist(self, bins=10, **kwds):
>>> df = ps.from_pandas(df)
>>> df.plot.hist(bins=12, alpha=0.5) # doctest: +SKIP
"""
if is_remote():
return unsupported_function(class_name="pd.DataFrame", method_name="hist")()

return self(kind="hist", bins=bins, **kwds)

def kde(self, bw_method=None, ind=None, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
class DataFramePlotMatplotlibParityTests(
DataFramePlotMatplotlibTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
):
@unittest.skip("Test depends on Spark ML which is not supported from Spark Connect.")
def test_hist_plot(self):
super().test_hist_plot()
pass


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,7 @@
class DataFramePlotPlotlyParityTests(
DataFramePlotPlotlyTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
):
@unittest.skip("Test depends on Spark ML which is not supported from Spark Connect.")
def test_hist_layout_kwargs(self):
super().test_hist_layout_kwargs()

@unittest.skip("Test depends on Spark ML which is not supported from Spark Connect.")
def test_hist_plot(self):
super().test_hist_plot()
pass


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,7 @@
class SeriesPlotMatplotlibParityTests(
SeriesPlotMatplotlibTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
):
@unittest.skip("Test depends on Spark ML which is not supported from Spark Connect.")
def test_empty_hist(self):
super().test_empty_hist()

@unittest.skip("Test depends on Spark ML which is not supported from Spark Connect.")
def test_hist(self):
super().test_hist()

@unittest.skip("Test depends on Spark ML which is not supported from Spark Connect.")
def test_hist_plot(self):
super().test_hist_plot()

@unittest.skip("Test depends on Spark ML which is not supported from Spark Connect.")
def test_single_value_hist(self):
super().test_single_value_hist()
pass


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
class SeriesPlotPlotlyParityTests(
SeriesPlotPlotlyTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
):
@unittest.skip("Test depends on Spark ML which is not supported from Spark Connect.")
def test_hist_plot(self):
super().test_hist_plot()
pass


if __name__ == "__main__":
Expand Down
27 changes: 0 additions & 27 deletions python/pyspark/pandas/tests/connect/test_connect_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import pandas as pd

from pyspark import pandas as ps
from pyspark.pandas.exceptions import PandasNotImplementedError
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils

Expand All @@ -37,32 +36,6 @@ def pdf1(self):
def psdf1(self):
return ps.from_pandas(self.pdf1)

def test_unsupported_functions(self):
with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot.hist()

with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot.hist(bins=3)

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot.hist()

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot.hist(bins=3)

def test_unsupported_kinds(self):
with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot(kind="hist")

with self.assertRaises(PandasNotImplementedError):
self.psdf1.plot(kind="hist", bins=3)

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot(kind="hist")

with self.assertRaises(PandasNotImplementedError):
self.psdf1.shield.plot(kind="hist", bins=3)


if __name__ == "__main__":
from pyspark.pandas.tests.connect.test_connect_plotting import * # noqa: F401
Expand Down

0 comments on commit 70b814b

Please sign in to comment.