[SPARK-51079][PYTHON] Support large variable types in pandas UDF, createDataFrame and toPandas with Arrow

### What changes were proposed in this pull request?

This PR is a retry of apache#41569, which implements the use of large variable types everywhere within PySpark.

apache#39572 implemented the core logic, but it only supports large variable types in the bold cases below:

- `mapInArrow`: **JVM -> Python -> JVM**
- Pandas UDF/Function API: **JVM -> Python** -> JVM
- createDataFrame with Arrow: Python -> JVM
- toPandas with Arrow: JVM -> Python

This PR completes them all.
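
For context (not part of this diff), a minimal usage sketch under the assumption of a local classic SparkSession with Arrow conversions enabled; the sample data and column names are illustrative:

```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .config("spark.sql.execution.arrow.useLargeVarTypes", "true")
    .getOrCreate()
)

# With the flag on, string/binary data flows through Arrow as
# large_string/large_binary in both directions listed above.
df = spark.createDataFrame([("a", bytearray(b"x"))], "s string, b binary")
pdf = df.toPandas()                 # JVM -> Python over Arrow
df_rt = spark.createDataFrame(pdf)  # Python -> JVM over Arrow
```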

### Why are the changes needed?

To support large variable types consistently.

### Does this PR introduce _any_ user-facing change?

`spark.sql.execution.arrow.useLargeVarTypes` has not been released yet, so this does not affect any end users.

### How was this patch tested?

Existing tests with `spark.sql.execution.arrow.useLargeVarTypes` enabled.

Closes apache#49790 from HyukjinKwon/SPARK-39979-followup2.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed Feb 5, 2025
1 parent 940eb36 commit e2ef5a4
Showing 33 changed files with 372 additions and 117 deletions.
10 changes: 9 additions & 1 deletion python/pyspark/pandas/typedef/typehints.py
@@ -296,7 +296,15 @@ def spark_type_to_pandas_dtype(
elif isinstance(spark_type, (types.TimestampType, types.TimestampNTZType)):
return np.dtype("datetime64[ns]")
else:
return np.dtype(to_arrow_type(spark_type).to_pandas_dtype())
from pyspark.pandas.utils import default_session

prefers_large_var_types = (
default_session()
.conf.get("spark.sql.execution.arrow.useLargeVarTypes", "false")
.lower()
== "true"
)
return np.dtype(to_arrow_type(spark_type, prefers_large_var_types).to_pandas_dtype())


def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype, types.DataType]:
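Why this is safe for pandas-on-Spark dtype inference (a sketch relying on PyArrow's standard pandas dtype mapping): the small and large variable-width Arrow types resolve to the same NumPy dtype, so only the intermediate Arrow representation changes.

```python
import numpy as np
import pyarrow as pa

# string/large_string and binary/large_binary all map to the generic NumPy
# object dtype, so spark_type_to_pandas_dtype returns the same value whether
# or not spark.sql.execution.arrow.useLargeVarTypes is enabled.
for arrow_type in (pa.string(), pa.large_string(), pa.binary(), pa.large_binary()):
    assert np.dtype(arrow_type.to_pandas_dtype()) == np.dtype(object)
```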
5 changes: 3 additions & 2 deletions python/pyspark/sql/connect/conversion.py
@@ -323,7 +323,7 @@ def convert_other(value: Any) -> Any:
return lambda value: value

@staticmethod
def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
def convert(data: Sequence[Any], schema: StructType, use_large_var_types: bool) -> "pa.Table":
assert isinstance(data, list) and len(data) > 0

assert schema is not None and isinstance(schema, StructType)
@@ -372,7 +372,8 @@ def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
)
for field in schema.fields
]
)
),
prefers_large_types=use_large_var_types,
)

return pa.Table.from_arrays(pylist, schema=pa_schema)
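A hedged call-site sketch for the new third parameter (the rows and schema below are invented; only the signature comes from this diff, and importing the Connect conversion module requires the Connect client dependencies):

```python
from pyspark.sql.connect.conversion import LocalDataToArrowConversion
from pyspark.sql.types import StringType, StructField, StructType

schema = StructType([StructField("name", StringType())])
rows = [("Alice",), ("Bob",)]

# Callers now resolve spark.sql.execution.arrow.useLargeVarTypes themselves
# and thread it through; with True, the resulting pa.Table uses large_string
# columns for StringType fields.
table = LocalDataToArrowConversion.convert(rows, schema, True)
```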
25 changes: 21 additions & 4 deletions python/pyspark/sql/connect/session.py
@@ -506,9 +506,13 @@ def createDataFrame(
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
"spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
"spark.sql.execution.arrow.useLargeVarTypes",
)
timezone = configs["spark.sql.session.timeZone"]
prefer_timestamp = configs["spark.sql.timestampType"]
prefers_large_types: bool = (
cast(str, configs["spark.sql.execution.arrow.useLargeVarTypes"]).lower() == "true"
)

_table: Optional[pa.Table] = None

@@ -552,7 +556,9 @@
if isinstance(schema, StructType):
deduped_schema = cast(StructType, _deduplicate_field_names(schema))
spark_types = [field.dataType for field in deduped_schema.fields]
arrow_schema = to_arrow_schema(deduped_schema)
arrow_schema = to_arrow_schema(
deduped_schema, prefers_large_types=prefers_large_types
)
arrow_types = [field.type for field in arrow_schema]
_cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
elif isinstance(schema, DataType):
@@ -570,7 +576,12 @@
else None
for t in data.dtypes
]
arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
arrow_types = [
to_arrow_type(dt, prefers_large_types=prefers_large_types)
if dt is not None
else None
for dt in spark_types
]

safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"]

@@ -609,7 +620,13 @@

_table = (
_check_arrow_table_timestamps_localize(data, schema, True, timezone)
.cast(to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True))
.cast(
to_arrow_schema(
schema,
error_on_duplicated_field_names_in_struct=True,
prefers_large_types=prefers_large_types,
)
)
.rename_columns(schema.names)
)

@@ -684,7 +701,7 @@ def createDataFrame(
# Spark Connect will try its best to build the Arrow table with the
# inferred schema in the client side, and then rename the columns and
# cast the datatypes in the server side.
_table = LocalDataToArrowConversion.convert(_data, _schema)
_table = LocalDataToArrowConversion.convert(_data, _schema, prefers_large_types)

# TODO: Beside the validation on number of columns, we should also check
# whether the Arrow Schema is compatible with the user provided Schema.
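What this means for a Spark Connect client, sketched under assumptions (the remote address is a placeholder, and it assumes the config is settable on the session):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
spark.conf.set("spark.sql.execution.arrow.useLargeVarTypes", "true")

# createDataFrame now fetches the flag together with the other inference
# configs and builds the client-side Arrow table with large variable-width
# types before shipping it to the server.
df = spark.createDataFrame([("a",), ("b",)], "value string")
```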
28 changes: 23 additions & 5 deletions python/pyspark/sql/pandas/conversion.py
@@ -81,7 +81,7 @@ def toPandas(self) -> "PandasDataFrameLike":
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version

require_minimum_pyarrow_version()
to_arrow_schema(self.schema)
to_arrow_schema(self.schema, prefers_large_types=jconf.arrowUseLargeVarTypes())
except Exception as e:
if jconf.arrowPySparkFallbackEnabled():
msg = (
@@ -236,7 +236,12 @@ def toArrow(self) -> "pa.Table":
from pyspark.sql.pandas.utils import require_minimum_pyarrow_version

require_minimum_pyarrow_version()
schema = to_arrow_schema(self.schema, error_on_duplicated_field_names_in_struct=True)
prefers_large_var_types = jconf.arrowUseLargeVarTypes()
schema = to_arrow_schema(
self.schema,
error_on_duplicated_field_names_in_struct=True,
prefers_large_types=prefers_large_var_types,
)

import pyarrow as pa

@@ -322,7 +327,8 @@ def _collect_as_arrow(
from pyspark.sql.pandas.types import to_arrow_schema
import pyarrow as pa

schema = to_arrow_schema(self.schema)
prefers_large_var_types = self.sparkSession._jconf.arrowUseLargeVarTypes()
schema = to_arrow_schema(self.schema, prefers_large_types=prefers_large_var_types)
empty_arrays = [pa.array([], type=field.type) for field in schema]
return [pa.RecordBatch.from_arrays(empty_arrays, schema=schema)]

@@ -715,9 +721,16 @@ def _create_from_pandas_with_arrow(
pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))

# Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream
prefers_large_var_types = self._jconf.arrowUseLargeVarTypes()
arrow_data = [
[
(c, to_arrow_type(t) if t is not None else None, t)
(
c,
to_arrow_type(t, prefers_large_types=prefers_large_var_types)
if t is not None
else None,
t,
)
for (_, c), t in zip(pdf_slice.items(), spark_types)
]
for pdf_slice in pdf_slices
@@ -785,8 +798,13 @@ def _create_from_arrow_table(
if not isinstance(schema, StructType):
schema = from_arrow_schema(table.schema, prefer_timestamp_ntz=prefer_timestamp_ntz)

prefers_large_var_types = self._jconf.arrowUseLargeVarTypes()
table = _check_arrow_table_timestamps_localize(table, schema, True, timezone).cast(
to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True)
to_arrow_schema(
schema,
error_on_duplicated_field_names_in_struct=True,
prefers_large_types=prefers_large_var_types,
)
)

# Chunk the Arrow Table into RecordBatches
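A sketch of the observable effect on toArrow (assuming a classic, non-Connect session and that the flag can be set at runtime):

```python
import pyarrow as pa
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.useLargeVarTypes", "true")

# toArrow now builds its target schema with prefers_large_types, so string
# columns come back as Arrow large_string instead of string.
tbl = spark.range(3).selectExpr("CAST(id AS STRING) AS s").toArrow()
assert tbl.schema.field("s").type == pa.large_string()
```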
9 changes: 7 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
@@ -793,6 +793,7 @@ def __init__(
assign_cols_by_name,
state_object_schema,
arrow_max_records_per_batch,
prefers_large_var_types,
):
super(ApplyInPandasWithStateSerializer, self).__init__(
timezone, safecheck, assign_cols_by_name
@@ -808,7 +809,9 @@
]
)

self.result_count_pdf_arrow_type = to_arrow_type(self.result_count_df_type)
self.result_count_pdf_arrow_type = to_arrow_type(
self.result_count_df_type, prefers_large_types=prefers_large_var_types
)

self.result_state_df_type = StructType(
[
@@ -819,7 +822,9 @@
]
)

self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type)
self.result_state_pdf_arrow_type = to_arrow_type(
self.result_state_df_type, prefers_large_types=prefers_large_var_types
)
self.arrow_max_records_per_batch = arrow_max_records_per_batch

def load_stream(self, stream):
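For orientation, a sketch of what the new constructor argument changes (the struct below is a stand-in, not the serializer's real result_state_df_type): string fields inside the serializer's helper struct types become large_string when the flag is enabled.

```python
import pyarrow as pa
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

# Stand-in for one of the serializer's internal struct types.
state_struct = StructType(
    [StructField("properties", StringType()), StructField("count", IntegerType())]
)

arrow_struct = to_arrow_type(state_struct, prefers_large_types=True)
assert arrow_struct.field("properties").type == pa.large_string()
```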
44 changes: 38 additions & 6 deletions python/pyspark/sql/pandas/types.py
@@ -67,6 +67,7 @@ def to_arrow_type(
dt: DataType,
error_on_duplicated_field_names_in_struct: bool = False,
timestamp_utc: bool = True,
prefers_large_types: bool = False,
) -> "pa.DataType":
"""
Convert Spark data type to PyArrow type
@@ -107,8 +108,12 @@
arrow_type = pa.float64()
elif type(dt) == DecimalType:
arrow_type = pa.decimal128(dt.precision, dt.scale)
elif type(dt) == StringType and prefers_large_types:
arrow_type = pa.large_string()
elif type(dt) == StringType:
arrow_type = pa.string()
elif type(dt) == BinaryType and prefers_large_types:
arrow_type = pa.large_binary()
elif type(dt) == BinaryType:
arrow_type = pa.binary()
elif type(dt) == DateType:
@@ -125,19 +130,34 @@
elif type(dt) == ArrayType:
field = pa.field(
"element",
to_arrow_type(dt.elementType, error_on_duplicated_field_names_in_struct, timestamp_utc),
to_arrow_type(
dt.elementType,
error_on_duplicated_field_names_in_struct,
timestamp_utc,
prefers_large_types,
),
nullable=dt.containsNull,
)
arrow_type = pa.list_(field)
elif type(dt) == MapType:
key_field = pa.field(
"key",
to_arrow_type(dt.keyType, error_on_duplicated_field_names_in_struct, timestamp_utc),
to_arrow_type(
dt.keyType,
error_on_duplicated_field_names_in_struct,
timestamp_utc,
prefers_large_types,
),
nullable=False,
)
value_field = pa.field(
"value",
to_arrow_type(dt.valueType, error_on_duplicated_field_names_in_struct, timestamp_utc),
to_arrow_type(
dt.valueType,
error_on_duplicated_field_names_in_struct,
timestamp_utc,
prefers_large_types,
),
nullable=dt.valueContainsNull,
)
arrow_type = pa.map_(key_field, value_field)
@@ -152,7 +172,10 @@
pa.field(
field.name,
to_arrow_type(
field.dataType, error_on_duplicated_field_names_in_struct, timestamp_utc
field.dataType,
error_on_duplicated_field_names_in_struct,
timestamp_utc,
prefers_large_types,
),
nullable=field.nullable,
)
@@ -163,7 +186,10 @@
arrow_type = pa.null()
elif isinstance(dt, UserDefinedType):
arrow_type = to_arrow_type(
dt.sqlType(), error_on_duplicated_field_names_in_struct, timestamp_utc
dt.sqlType(),
error_on_duplicated_field_names_in_struct,
timestamp_utc,
prefers_large_types,
)
elif type(dt) == VariantType:
fields = [
@@ -185,6 +211,7 @@ def to_arrow_schema(
schema: StructType,
error_on_duplicated_field_names_in_struct: bool = False,
timestamp_utc: bool = True,
prefers_large_types: bool = False,
) -> "pa.Schema":
"""
Convert a schema from Spark to Arrow
@@ -212,7 +239,12 @@
fields = [
pa.field(
field.name,
to_arrow_type(field.dataType, error_on_duplicated_field_names_in_struct, timestamp_utc),
to_arrow_type(
field.dataType,
error_on_duplicated_field_names_in_struct,
timestamp_utc,
prefers_large_types,
),
nullable=field.nullable,
)
for field in schema
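The resulting mapping, sketched with the PyArrow factories named in the diff: the flag swaps the variable-width leaf types and is forwarded recursively into array, map, and struct children.

```python
import pyarrow as pa
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import ArrayType, BinaryType, MapType, StringType

assert to_arrow_type(StringType(), prefers_large_types=True) == pa.large_string()
assert to_arrow_type(BinaryType(), prefers_large_types=True) == pa.large_binary()
assert to_arrow_type(StringType(), prefers_large_types=False) == pa.string()

# The preference propagates into nested element, key, and value types.
arr = to_arrow_type(ArrayType(StringType()), prefers_large_types=True)
assert arr.value_type == pa.large_string()

m = to_arrow_type(MapType(StringType(), BinaryType()), prefers_large_types=True)
assert m.key_type == pa.large_string() and m.item_type == pa.large_binary()
```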
6 changes: 5 additions & 1 deletion python/pyspark/sql/tests/arrow/test_arrow.py
@@ -730,7 +730,11 @@ def test_createDataFrame_arrow_truncate_timestamp(self):
def test_schema_conversion_roundtrip(self):
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema

arrow_schema = to_arrow_schema(self.schema)
arrow_schema = to_arrow_schema(self.schema, prefers_large_types=False)
schema_rt = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True)
self.assertEqual(self.schema, schema_rt)

arrow_schema = to_arrow_schema(self.schema, prefers_large_types=True)
schema_rt = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True)
self.assertEqual(self.schema, schema_rt)

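The roundtrip above relies on from_arrow_schema folding both the small and large variants back into the same Spark types; a small sketch of that property:

```python
import pyarrow as pa
from pyspark.sql.pandas.types import from_arrow_schema
from pyspark.sql.types import BinaryType, StringType, StructField, StructType

arrow_schema = pa.schema([("s", pa.large_string()), ("b", pa.large_binary())])
expected = StructType([StructField("s", StringType()), StructField("b", BinaryType())])
assert from_arrow_schema(arrow_schema) == expected
```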