[SPARK-39979][SQL][FOLLOW-UP] Support large variable types in pandas UDF, createDataFrame and toPandas with Arrow #41569

Closed · wants to merge 2 commits
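This follow-up threads the `spark.sql.execution.arrow.useLargeVarTypes` setting through the Python Arrow paths (pandas UDFs, `createDataFrame`, `toPandas`) so string and binary columns can be exchanged as Arrow `large_string`/`large_binary` instead of `string`/`binary`. The sketch below is not part of the patch; it is a minimal pyarrow-only illustration of the distinction the flag opts into.

```python
import pyarrow as pa

# 32-bit offsets: the data buffer of a single string array is capped near 2 GiB.
small = pa.array(["spark", "arrow"], type=pa.string())
# 64-bit offsets: the "large" variants lift that cap at the cost of wider offsets
# (more memory per value), which is what useLargeVarTypes opts into.
large = pa.array(["spark", "arrow"], type=pa.large_string())

print(small.type)  # string
print(large.type)  # large_string
```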
8 changes: 7 additions & 1 deletion python/pyspark/pandas/typedef/typehints.py
@@ -296,7 +296,13 @@ def spark_type_to_pandas_dtype(
elif isinstance(spark_type, types.TimestampType):
return np.dtype("datetime64[ns]")
else:
return np.dtype(to_arrow_type(spark_type).to_pandas_dtype())
from pyspark.pandas.utils import default_session

prefers_large_var_types = (
default_session().conf.get("spark.sql.execution.arrow.useLargeVarTypes", "true")
== "true"
)
return np.dtype(to_arrow_type(spark_type, prefers_large_var_types).to_pandas_dtype())


def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype, types.DataType]:
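As a sanity check (not part of the change), both Arrow string variants map to the same NumPy dtype, so `spark_type_to_pandas_dtype` still returns `object` for strings; reading the session conf here only keeps the Arrow conversion consistent with the rest of the session.

```python
import numpy as np
import pyarrow as pa

# Both variants round-trip to NumPy object dtype, so the pandas-on-Spark dtype
# mapping is unchanged; only the Arrow-level representation differs.
assert np.dtype(pa.string().to_pandas_dtype()) == np.dtype("object")
assert np.dtype(pa.large_string().to_pandas_dtype()) == np.dtype("object")
```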
5 changes: 3 additions & 2 deletions python/pyspark/sql/connect/conversion.py
@@ -260,7 +260,7 @@ def convert_udt(value: Any) -> Any:
return lambda value: value

@staticmethod
def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
def convert(data: Sequence[Any], schema: StructType, use_large_var_types: bool) -> "pa.Table":
assert isinstance(data, list) and len(data) > 0

assert schema is not None and isinstance(schema, StructType)
@@ -296,7 +296,8 @@ def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
)
for field in schema.fields
]
)
),
use_large_var_types,
)

return pa.Table.from_arrays(pylist, schema=pa_schema)
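A minimal pyarrow sketch of the kind of table `LocalDataToArrowConversion.convert` now targets when the new `use_large_var_types` argument is `True` (the column names and values below are made up for illustration):

```python
import pyarrow as pa

# With the flag on, the Arrow schema carries large_string / large_binary.
pa_schema = pa.schema([
    pa.field("name", pa.large_string()),
    pa.field("payload", pa.large_binary()),
])
table = pa.Table.from_arrays(
    [
        pa.array(["a", "bb"], type=pa.large_string()),
        pa.array([b"\x01", b"\x02\x03"], type=pa.large_binary()),
    ],
    schema=pa_schema,
)
print(table.schema)  # name: large_string, payload: large_binary
```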
15 changes: 12 additions & 3 deletions python/pyspark/sql/connect/session.py
@@ -338,13 +338,16 @@ def createDataFrame(
_num_cols = len(_cols)

# Determine arrow types to coerce data when creating batches
prefers_large_types: bool = (
self._client.get_configs("spark.sql.execution.arrow.useLargeVarTypes")[0] == "true"
)
arrow_schema: Optional[pa.Schema] = None
spark_types: List[Optional[DataType]]
arrow_types: List[Optional[pa.DataType]]
if isinstance(schema, StructType):
deduped_schema = cast(StructType, _deduplicate_field_names(schema))
spark_types = [field.dataType for field in deduped_schema.fields]
arrow_schema = to_arrow_schema(deduped_schema)
arrow_schema = to_arrow_schema(deduped_schema, prefers_large_types)
arrow_types = [field.type for field in arrow_schema]
_cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
elif isinstance(schema, DataType):
@@ -362,7 +365,10 @@
else None
for t in data.dtypes
]
arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
arrow_types = [
to_arrow_type(dt, prefers_large_types) if dt is not None else None
for dt in spark_types
]

timezone, safecheck = self._client.get_configs(
"spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
@@ -459,7 +465,10 @@ def createDataFrame(
# Spark Connect will try its best to build the Arrow table with the
# inferred schema in the client side, and then rename the columns and
# cast the datatypes in the server side.
_table = LocalDataToArrowConversion.convert(_data, _schema)
prefers_large_types: bool = (
self._client.get_configs("spark.sql.execution.arrow.useLargeVarTypes")[0] == "true"
)
_table = LocalDataToArrowConversion.convert(_data, _schema, prefers_large_types)

# TODO: Beside the validation on number of columns, we should also check
# whether the Arrow Schema is compatible with the user provided Schema.
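A hedged, user-level equivalent of the lookup `createDataFrame` performs on the Connect client, assuming an existing Spark Connect session named `spark` and using the public `spark.conf` API rather than the internal `_client.get_configs`:

```python
# Read the session conf and derive the boolean that drives the Arrow conversion.
prefers_large_types = (
    spark.conf.get("spark.sql.execution.arrow.useLargeVarTypes", "false").lower() == "true"
)
# createDataFrame() forwards this flag to to_arrow_schema / to_arrow_type when it
# builds the Arrow table on the client before shipping it to the server.
```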
5 changes: 3 additions & 2 deletions python/pyspark/sql/pandas/conversion.py
@@ -96,7 +96,7 @@ def toPandas(self) -> "PandasDataFrameLike":
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version

require_minimum_pyarrow_version()
to_arrow_schema(self.schema)
to_arrow_schema(self.schema, jconf.arrowUseLargeVarTypes())
except Exception as e:

if jconf.arrowPySparkFallbackEnabled():
@@ -616,9 +616,10 @@ def _create_from_pandas_with_arrow(
pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))

# Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream
prefers_large_var_types = self._jconf.arrowUseLargeVarTypes()
arrow_data = [
[
(c, to_arrow_type(t) if t is not None else None, t)
(c, to_arrow_type(t, prefers_large_var_types) if t is not None else None, t)
for (_, c), t in zip(pdf_slice.items(), spark_types)
]
for pdf_slice in pdf_slices
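An end-to-end sketch under the assumption of a running classic PySpark session named `spark`: with Arrow enabled and the flag on, `toPandas` and `createDataFrame` exchange string columns as `large_string` under the hood while the pandas result is unchanged.

```python
import pandas as pd

spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.useLargeVarTypes", "true")

# toPandas: the Arrow schema check above now honours the flag.
pdf = spark.range(3).selectExpr("cast(id AS string) AS s").toPandas()

# createDataFrame from pandas: each slice is serialized with large var types.
df = spark.createDataFrame(pd.DataFrame({"s": ["x", "y", "z"]}))
df.show()
```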
5 changes: 4 additions & 1 deletion python/pyspark/sql/pandas/serializers.py
@@ -478,6 +478,7 @@ def __init__(
assign_cols_by_name,
state_object_schema,
arrow_max_records_per_batch,
prefers_large_var_types,
):
super(ApplyInPandasWithStateSerializer, self).__init__(
timezone, safecheck, assign_cols_by_name
@@ -495,7 +496,9 @@ def __init__(
]
)

self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type)
self.result_state_pdf_arrow_type = to_arrow_type(
self.result_state_df_type, prefers_large_var_types
)
self.arrow_max_records_per_batch = arrow_max_records_per_batch

def load_stream(self, stream):
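The serializer's actual state-result schema is elided above; the hypothetical struct below only illustrates how `to_arrow_type` treats string and binary children once `prefers_large_var_types` is passed through.

```python
from pyspark.sql.types import BinaryType, LongType, StringType, StructField, StructType
from pyspark.sql.pandas.types import to_arrow_type

# Hypothetical stand-in for a state-result schema, not the serializer's real one.
state_df_type = StructType([
    StructField("properties", StringType()),
    StructField("state", BinaryType()),
    StructField("timeoutTimestamp", LongType()),
])
print(to_arrow_type(state_df_type, True))
# struct<properties: large_string, state: large_binary, timeoutTimestamp: int64>
```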
30 changes: 22 additions & 8 deletions python/pyspark/sql/pandas/types.py
@@ -59,7 +59,7 @@
from pyspark.sql.pandas._typing import SeriesLike as PandasSeriesLike


def to_arrow_type(dt: DataType) -> "pa.DataType":
def to_arrow_type(dt: DataType, prefers_large_types: bool = False) -> "pa.DataType":
"""Convert Spark data type to pyarrow type"""
from distutils.version import LooseVersion
import pyarrow as pa
@@ -80,8 +80,12 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
arrow_type = pa.float64()
elif type(dt) == DecimalType:
arrow_type = pa.decimal128(dt.precision, dt.scale)
elif type(dt) == StringType and prefers_large_types:
arrow_type = pa.large_string()
elif type(dt) == StringType:
arrow_type = pa.string()
elif type(dt) == BinaryType and prefers_large_types:
arrow_type = pa.large_binary()
elif type(dt) == BinaryType:
arrow_type = pa.binary()
elif type(dt) == DateType:
@@ -101,27 +105,35 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_VERSION",
message_parameters={"data_type": "Array of StructType"},
)
field = pa.field("element", to_arrow_type(dt.elementType), nullable=dt.containsNull)
field = pa.field(
"element", to_arrow_type(dt.elementType, prefers_large_types), nullable=dt.containsNull
)
arrow_type = pa.list_(field)
elif type(dt) == MapType:
if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
raise PySparkTypeError(
error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_VERSION",
message_parameters={"data_type": "MapType"},
)
key_field = pa.field("key", to_arrow_type(dt.keyType), nullable=False)
value_field = pa.field("value", to_arrow_type(dt.valueType), nullable=dt.valueContainsNull)
key_field = pa.field("key", to_arrow_type(dt.keyType, prefers_large_types), nullable=False)
value_field = pa.field(
"value", to_arrow_type(dt.valueType, prefers_large_types), nullable=dt.valueContainsNull
)
arrow_type = pa.map_(key_field, value_field)
elif type(dt) == StructType:
fields = [
pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
pa.field(
field.name,
to_arrow_type(field.dataType, prefers_large_types),
nullable=field.nullable,
)
for field in dt
]
arrow_type = pa.struct(fields)
elif type(dt) == NullType:
arrow_type = pa.null()
elif isinstance(dt, UserDefinedType):
arrow_type = to_arrow_type(dt.sqlType())
arrow_type = to_arrow_type(dt.sqlType(), prefers_large_types)
else:
raise PySparkTypeError(
error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
@@ -130,12 +142,14 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
return arrow_type


def to_arrow_schema(schema: StructType) -> "pa.Schema":
def to_arrow_schema(schema: StructType, prefers_large_types: bool = False) -> "pa.Schema":
"""Convert a schema from Spark to Arrow"""
import pyarrow as pa

fields = [
pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
pa.field(
field.name, to_arrow_type(field.dataType, prefers_large_types), nullable=field.nullable
)
for field in schema
]
return pa.schema(fields)
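The new parameter recurses through nested types. A small sketch of the resulting mapping (the expected-output comments are approximate renderings):

```python
import pyarrow as pa
from pyspark.sql.types import (
    ArrayType, BinaryType, MapType, StringType, StructField, StructType,
)
from pyspark.sql.pandas.types import to_arrow_schema

schema = StructType([
    StructField("s", StringType()),
    StructField("b", BinaryType()),
    StructField("arr", ArrayType(StringType())),
    StructField("m", MapType(StringType(), BinaryType())),
])

# Default keeps the 32-bit-offset types.
assert to_arrow_schema(schema).field("s").type == pa.string()

# With the flag, string/binary become large_string/large_binary, recursively.
large = to_arrow_schema(schema, prefers_large_types=True)
assert large.field("s").type == pa.large_string()
assert large.field("b").type == pa.large_binary()
print(large.field("arr").type)  # list<element: large_string>
print(large.field("m").type)    # map<large_string, large_binary>
```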
6 changes: 5 additions & 1 deletion python/pyspark/sql/tests/test_arrow.py
@@ -574,7 +574,11 @@ def test_createDataFrame_does_not_modify_input(self):
def test_schema_conversion_roundtrip(self):
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema

arrow_schema = to_arrow_schema(self.schema)
arrow_schema = to_arrow_schema(self.schema, prefers_large_types=False)
schema_rt = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True)
self.assertEqual(self.schema, schema_rt)

arrow_schema = to_arrow_schema(self.schema, prefers_large_types=True)
schema_rt = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True)
self.assertEqual(self.schema, schema_rt)

43 changes: 27 additions & 16 deletions python/pyspark/worker.py
@@ -104,8 +104,8 @@ def wrap_udf(f, return_type):
return lambda *a: f(*a)


def wrap_scalar_pandas_udf(f, return_type):
arrow_return_type = to_arrow_type(return_type)
def wrap_scalar_pandas_udf(f, return_type, runner_conf):
arrow_return_type = to_arrow_type(return_type, use_large_var_types(runner_conf))

def verify_result_type(result):
if not hasattr(result, "__len__"):
@@ -133,8 +133,8 @@ def verify_result_length(result, length):
)


def wrap_batch_iter_udf(f, return_type):
arrow_return_type = to_arrow_type(return_type)
def wrap_batch_iter_udf(f, return_type, runner_conf):
arrow_return_type = to_arrow_type(return_type, use_large_var_types(runner_conf))

def verify_result_type(result):
if not hasattr(result, "__len__"):
@@ -196,6 +196,7 @@ def verify_pandas_result(result, return_type, assign_cols_by_name):


def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
_use_large_var_types = use_large_var_types(runner_conf)
_assign_cols_by_name = assign_cols_by_name(runner_conf)

def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
@@ -214,10 +215,13 @@ def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):

return result

return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), to_arrow_type(return_type))]
return lambda kl, vl, kr, vr: [
(wrapped(kl, vl, kr, vr), to_arrow_type(return_type, _use_large_var_types))
]


def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
_use_large_var_types = use_large_var_types(runner_conf)
_assign_cols_by_name = assign_cols_by_name(runner_conf)

def wrapped(key_series, value_series):
@@ -232,10 +236,10 @@ def wrapped(key_series, value_series):

return result

return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type, _use_large_var_types))]


def wrap_grouped_map_pandas_udf_with_state(f, return_type):
def wrap_grouped_map_pandas_udf_with_state(f, return_type, runner_conf):
"""
Provides a new lambda instance wrapping user function of applyInPandasWithState.

@@ -250,6 +254,7 @@ def wrap_grouped_map_pandas_udf_with_state(f, return_type):
Along with the returned iterator, the lambda instance will also produce the return_type as
converted to the arrow schema.
"""
_use_large_var_types = use_large_var_types(runner_conf)

def wrapped(key_series, value_series_gen, state):
"""
@@ -318,11 +323,11 @@ def verify_element(result):
state,
)

return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))]
return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type, _use_large_var_types))]


def wrap_grouped_agg_pandas_udf(f, return_type):
arrow_return_type = to_arrow_type(return_type)
def wrap_grouped_agg_pandas_udf(f, return_type, runner_conf):
arrow_return_type = to_arrow_type(return_type, use_large_var_types(runner_conf))

def wrapped(*series):
import pandas as pd
@@ -420,23 +425,23 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):

# the last returnType will be the return type of UDF
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF):
return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
return arg_offsets, wrap_scalar_pandas_udf(func, return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
return arg_offsets, wrap_batch_iter_udf(func, return_type)
return arg_offsets, wrap_batch_iter_udf(func, return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
return arg_offsets, wrap_batch_iter_udf(func, return_type)
return arg_offsets, wrap_batch_iter_udf(func, return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
return arg_offsets, wrap_batch_iter_udf(func, return_type)
return arg_offsets, wrap_batch_iter_udf(func, return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
argspec = getfullargspec(chained_func) # signature was lost when wrapping it
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type)
return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
argspec = getfullargspec(chained_func) # signature was lost when wrapping it
return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type, runner_conf)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
return arg_offsets, wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index)
elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
Expand All @@ -456,6 +461,10 @@ def assign_cols_by_name(runner_conf):
)


def use_large_var_types(runner_conf):
return runner_conf.get("spark.sql.execution.arrow.useLargeVarTypes", "true").lower() == "true"


# Read and process a serialized user-defined table function (UDTF) from a socket.
# It expects the UDTF to be in a specific format and performs various checks to
# ensure the UDTF is valid. This function also prepares a mapper function for applying
@@ -563,6 +572,7 @@ def read_udfs(pickleSer, infile, eval_type):

# NOTE: if timezone is set here, that implies respectSessionTimeZone is True
timezone = runner_conf.get("spark.sql.session.timeZone", None)
prefers_large_var_types = use_large_var_types(runner_conf)
safecheck = (
runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
== "true"
@@ -582,6 +592,7 @@ def read_udfs(pickleSer, infile, eval_type):
assign_cols_by_name(runner_conf),
state_object_schema,
arrow_max_records_per_batch,
prefers_large_var_types,
)
elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
ser = ArrowStreamUDFSerializer()
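A quick sketch of the new helper's behaviour on a worker-side conf dict; it assumes `use_large_var_types` can be imported from `pyspark.worker`. Note the fallback is `"true"` when the key is absent, matching the new SQLConf default below.

```python
from pyspark.worker import use_large_var_types

assert use_large_var_types({"spark.sql.execution.arrow.useLargeVarTypes": "True"}) is True
assert use_large_var_types({"spark.sql.execution.arrow.useLargeVarTypes": "false"}) is False
assert use_large_var_types({}) is True  # default when the conf is not forwarded
```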
@@ -2834,7 +2834,7 @@ object SQLConf {
"usage per value.")
.version("3.5.0")
.booleanConf
.createWithDefault(false)
.createWithDefault(true)

val PANDAS_UDF_BUFFER_SIZE =
buildConf("spark.sql.execution.pandas.udf.buffer.size")
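With the default flipped to `true`, the behaviour can still be disabled per session when the wider offsets are not worth the extra memory per value (sketch assumes an existing session `spark`):

```python
spark.conf.set("spark.sql.execution.arrow.useLargeVarTypes", "false")
print(spark.conf.get("spark.sql.execution.arrow.useLargeVarTypes"))  # 'false'
```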
@@ -179,7 +179,9 @@ private[sql] object ArrowUtils {
conf.pandasGroupedMapAssignColumnsByName.toString)
val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
conf.arrowSafeTypeConversion.toString)
Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
val useLargeVarTypes = Seq(SQLConf.ARROW_EXECUTION_USE_LARGE_VAR_TYPES.key ->
conf.arrowUseLargeVarTypes.toString)
Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++ useLargeVarTypes: _*)
}

private def deduplicateFieldNames(
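For reference, this map is what reaches the Python worker as `runner_conf` in `worker.py` above. A hypothetical rendering with illustrative values, limited to the keys visible in this diff (the assign-columns-by-name entry is omitted because its key string is not shown here):

```python
runner_conf = {
    "spark.sql.session.timeZone": "UTC",
    "spark.sql.execution.pandas.convertToArrowArraySafely": "false",
    "spark.sql.execution.arrow.useLargeVarTypes": "true",
}
# worker.py's use_large_var_types(runner_conf) then yields True for this map.
```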