Commit db4cab6

init

zhengruifeng committed Sep 3, 2024
1 parent 32b054c commit db4cab6
Showing 3 changed files with 42 additions and 20 deletions.
46 changes: 26 additions & 20 deletions python/pyspark/pandas/plot/core.py
@@ -25,7 +25,6 @@
 from pandas.core.dtypes.inference import is_integer
 
 from pyspark.sql import functions as F, Column
-from pyspark.sql.types import DoubleType
 from pyspark.pandas.spark import functions as SF
 from pyspark.pandas.missing import unsupported_function
 from pyspark.pandas.config import get_option
@@ -182,33 +181,40 @@ def compute_hist(psdf, bins):
         colnames = sdf.columns
         bucket_names = ["__{}_bucket".format(colname) for colname in colnames]
 
-        # TODO(SPARK-49202): register this function in scala side
-        @F.udf(returnType=DoubleType())
-        def binary_search_for_buckets(value):
-            # Given bins = [1.0, 2.0, 3.0, 4.0]
-            # the intervals are:
-            #   [1.0, 2.0) -> 0.0
-            #   [2.0, 3.0) -> 1.0
-            #   [3.0, 4.0] -> 2.0 (the last bucket is a closed interval)
-            if value < bins[0] or value > bins[-1]:
-                raise ValueError(f"value {value} out of the bins bounds: [{bins[0]}, {bins[-1]}]")
-
-            if value == bins[-1]:
-                idx = len(bins) - 2
-            else:
-                idx = bisect.bisect(bins, value) - 1
-            return float(idx)
+        # refers to org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets
+        def binary_search_for_buckets(value: Column):
+            index = SF.binary_search(F.lit(bins), value)
+            bucket = (
+                F.when(value == F.lit(bins[-1]), F.lit(len(bins) - 2))
+                .when(index > F.lit(0), index)
+                .otherwise(-index - F.lit(2))
+            )
+
+            return (
+                F.when(value.between(F.lit(bins[0]), F.lit(bins[-1])), bucket)
+                .when(value.isNaN(), F.raise_error(F.lit("Histogram encountered NaN value.")))
+                .otherwise(
+                    F.raise_error(
+                        F.printf(
+                            F.lit("value %s out of the bins bounds: [%s, %s]"),
+                            value,
+                            F.lit(bins[0]),
+                            F.lit(bins[-1]),
+                        )
+                    )
+                )
+            )
 
         output_df = (
             sdf.select(
                 F.posexplode(
                     F.array([F.col(colname).cast("double") for colname in colnames])
                 ).alias("__group_id", "__value")
             )
             # to match handleInvalid="skip" in Bucketizer
-            .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN()).select(
+            .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN())
+            .select(
                 F.col("__group_id"),
-                binary_search_for_buckets(F.col("__value")).alias("__bucket"),
+                binary_search_for_buckets(F.col("__value")).cast("double").alias("__bucket"),
             )
         )
 
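The new Column expression leans on java.util.Arrays.binarySearch semantics (via the referenced Bucketizer#binarySearchForBuckets): a hit returns the matching index, a miss returns -(insertionPoint) - 1, so the .otherwise(-index - F.lit(2)) branch recovers insertionPoint - 1, the bucket just left of the insertion point. Below is a minimal plain-Python sketch of the same lookup, assuming those Java-style semantics for array_binary_search; the helper name bucket_for is illustrative and not part of the commit.

import bisect

def bucket_for(value, bins):
    # Emulate Java-style binary search: a hit returns the matching
    # index, a miss returns -(insertionPoint) - 1.
    i = bisect.bisect_left(bins, value)
    index = i if i < len(bins) and bins[i] == value else -(i + 1)

    # Mirror the committed Column expression above.
    if value == bins[-1]:
        return len(bins) - 2   # the last bucket is a closed interval
    if index > 0:
        return index           # exact hit on an interior split
    return -index - 2          # miss: insertionPoint - 1

bins = [1.0, 2.0, 3.0, 4.0]
assert bucket_for(1.5, bins) == 0  # [1.0, 2.0)
assert bucket_for(3.0, bins) == 2  # hit on a split -> [3.0, 4.0]
assert bucket_for(4.0, bins) == 2  # closed upper bound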
13 changes: 13 additions & 0 deletions python/pyspark/pandas/spark/functions.py
@@ -187,6 +187,19 @@ def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
     return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse))
 
 
+def binary_search(col: Column, value: Column) -> Column:
+    if is_remote():
+        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
+
+        return _invoke_function_over_columns("array_binary_search", col, value)
+
+    else:
+        from pyspark import SparkContext
+
+        sc = SparkContext._active_spark_context
+        return Column(sc._jvm.PythonSQLUtils.binary_search(col._jc, value._jc))
+
+
 def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
     unit_mapping = {
         "YEAR": "years",
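A quick usage sketch for the new helper; the session setup, column name, and probe value below are illustrative, not part of the commit. Assuming the Java-style contract described above, probing 2.5 against [1.0, 2.0, 3.0, 4.0] misses with insertionPoint = 2 and should yield -3.

from pyspark.sql import SparkSession, functions as F
from pyspark.pandas.spark import functions as SF

spark = SparkSession.builder.getOrCreate()

# Search a sorted literal array for a per-row value.
# F.lit on a Python list builds an array literal, as in compute_hist above.
df = spark.range(1).select(F.lit(2.5).alias("v"))
df.select(SF.binary_search(F.lit([1.0, 2.0, 3.0, 4.0]), F.col("v")).alias("idx")).show()
# Expected, assuming Java binarySearch semantics: idx == -3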
3 changes: 3 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -152,6 +152,9 @@ private[sql] object PythonSQLUtils extends Logging {
   def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
     Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
 
+  def binary_search(e: Column, value: Column): Column =
+    Column.internalFn("array_binary_search", e, value)
+
   def pandasProduct(e: Column, ignoreNA: Boolean): Column =
     Column.internalFn("pandas_product", e, lit(ignoreNA))
 
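Note how the two dispatch paths converge: under Spark Connect the Python helper invokes array_binary_search by name, while classic mode goes through PythonSQLUtils.binary_search, which resolves the same internal expression via Column.internalFn.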
