From 3dee0b585a0f5ad1de98008d5cfa826e1176d1cb Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Thu, 6 Feb 2025 08:57:52 +0900
Subject: [PATCH] [SPARK-51079][PYTHON] Support large variable types in pandas UDF, createDataFrame and toPandas with Arrow

### What changes were proposed in this pull request?

This PR is a retry of https://github.com/apache/spark/pull/41569, which implements support for large variable types everywhere within PySpark.

https://github.com/apache/spark/pull/39572 implemented the core logic, but it only supports large variable types in the bold cases below:

- `mapInArrow`: **JVM -> Python -> JVM**
- Pandas UDF/Function API: **JVM -> Python** -> JVM
- createDataFrame with Arrow: Python -> JVM
- toPandas with Arrow: JVM -> Python

This PR completes them all.

### Why are the changes needed?

To consistently support large variable types.

### Does this PR introduce _any_ user-facing change?

`spark.sql.execution.arrow.useLargeVarTypes` has not been released yet, so this does not affect any end users.

### How was this patch tested?

Existing tests with `spark.sql.execution.arrow.useLargeVarTypes` enabled.

Closes #49790 from HyukjinKwon/SPARK-39979-followup2.

Authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 python/pyspark/pandas/typedef/typehints.py    | 10 +-
 python/pyspark/sql/connect/conversion.py      |  5 +-
 python/pyspark/sql/connect/session.py         | 25 ++++-
 python/pyspark/sql/pandas/conversion.py       | 28 +++++-
 python/pyspark/sql/pandas/serializers.py      |  9 +-
 python/pyspark/sql/pandas/types.py            | 44 +++++++--
 python/pyspark/sql/tests/arrow/test_arrow.py  |  6 +-
 python/pyspark/worker.py                      | 95 +++++++++++++------
 .../spark/sql/internal/SqlApiConf.scala       |  2 +
 .../spark/sql/internal/SqlApiConfHelper.scala |  1 +
 .../apache/spark/sql/util/ArrowUtils.scala    |  2 +-
 .../sql/execution/arrow/ArrowWriter.scala     |  6 +-
 .../apache/spark/sql/internal/SQLConf.scala   |  3 +-
 .../spark/sql/util/ArrowUtilsSuite.scala      |  9 +-
 .../sql/connect/SQLImplicitsTestSuite.scala   |  3 +-
 .../client/arrow/ArrowEncoderSuite.scala      |  9 +-
 .../spark/sql/connect/SparkSession.scala      |  6 +-
 .../client/arrow/ArrowSerializer.scala        | 44 +++++++--
 .../client/arrow/ArrowVectorReader.scala      | 21 +++-
 .../execution/SparkConnectPlanExecution.scala | 13 ++-
 .../connect/planner/SparkConnectPlanner.scala | 10 +-
 .../sql/connect/SparkConnectServerTest.scala  |  7 +-
 .../planner/SparkConnectPlannerSuite.scala    |  5 +-
 .../planner/SparkConnectProtoSuite.scala      |  3 +-
 .../org/apache/spark/sql/api/r/SQLUtils.scala |  8 +-
 .../apache/spark/sql/classic/Dataset.scala    | 17 +++-
 .../sql/execution/arrow/ArrowConverters.scala | 52 +++++++---
 .../execution/python/ArrowPythonRunner.scala  |  5 +-
 .../python/CoGroupedArrowPythonRunner.scala   |  4 +-
 .../python/FlatMapCoGroupsInBatchExec.scala   |  2 +
 .../spark/sql/execution/r/ArrowRRunner.scala  |  3 +-
 .../arrow/ArrowConvertersSuite.scala          | 29 +++---
 .../execution/arrow/ArrowWriterSuite.scala    |  3 +-
 33 files changed, 372 insertions(+), 117 deletions(-)

diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py
index cb82cf8d71498..4244f5831aa50 100644
--- a/python/pyspark/pandas/typedef/typehints.py
+++ b/python/pyspark/pandas/typedef/typehints.py
@@ -296,7 +296,15 @@ def spark_type_to_pandas_dtype(
     elif isinstance(spark_type, (types.TimestampType, types.TimestampNTZType)):
         return np.dtype("datetime64[ns]")
     else:
-        return np.dtype(to_arrow_type(spark_type).to_pandas_dtype())
+        from pyspark.pandas.utils import default_session
+
+        prefers_large_var_types = (
default_session() + .conf.get("spark.sql.execution.arrow.useLargeVarTypes", "false") + .lower() + == "true" + ) + return np.dtype(to_arrow_type(spark_type, prefers_large_var_types).to_pandas_dtype()) def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype, types.DataType]: diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py index d4363594a3153..d36baacb10a34 100644 --- a/python/pyspark/sql/connect/conversion.py +++ b/python/pyspark/sql/connect/conversion.py @@ -323,7 +323,7 @@ def convert_other(value: Any) -> Any: return lambda value: value @staticmethod - def convert(data: Sequence[Any], schema: StructType) -> "pa.Table": + def convert(data: Sequence[Any], schema: StructType, use_large_var_types: bool) -> "pa.Table": assert isinstance(data, list) and len(data) > 0 assert schema is not None and isinstance(schema, StructType) @@ -372,7 +372,8 @@ def convert(data: Sequence[Any], schema: StructType) -> "pa.Table": ) for field in schema.fields ] - ) + ), + prefers_large_types=use_large_var_types, ) return pa.Table.from_arrays(pylist, schema=pa_schema) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 59349a17886bb..c01c1e42a3185 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -506,9 +506,13 @@ def createDataFrame( "spark.sql.pyspark.inferNestedDictAsStruct.enabled", "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled", "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled", + "spark.sql.execution.arrow.useLargeVarTypes", ) timezone = configs["spark.sql.session.timeZone"] prefer_timestamp = configs["spark.sql.timestampType"] + prefers_large_types: bool = ( + cast(str, configs["spark.sql.execution.arrow.useLargeVarTypes"]).lower() == "true" + ) _table: Optional[pa.Table] = None @@ -552,7 +556,9 @@ def createDataFrame( if isinstance(schema, StructType): deduped_schema = cast(StructType, _deduplicate_field_names(schema)) spark_types = [field.dataType for field in deduped_schema.fields] - arrow_schema = to_arrow_schema(deduped_schema) + arrow_schema = to_arrow_schema( + deduped_schema, prefers_large_types=prefers_large_types + ) arrow_types = [field.type for field in arrow_schema] _cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()] elif isinstance(schema, DataType): @@ -570,7 +576,12 @@ def createDataFrame( else None for t in data.dtypes ] - arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types] + arrow_types = [ + to_arrow_type(dt, prefers_large_types=prefers_large_types) + if dt is not None + else None + for dt in spark_types + ] safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"] @@ -609,7 +620,13 @@ def createDataFrame( _table = ( _check_arrow_table_timestamps_localize(data, schema, True, timezone) - .cast(to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True)) + .cast( + to_arrow_schema( + schema, + error_on_duplicated_field_names_in_struct=True, + prefers_large_types=prefers_large_types, + ) + ) .rename_columns(schema.names) ) @@ -684,7 +701,7 @@ def createDataFrame( # Spark Connect will try its best to build the Arrow table with the # inferred schema in the client side, and then rename the columns and # cast the datatypes in the server side. 
- _table = LocalDataToArrowConversion.convert(_data, _schema) + _table = LocalDataToArrowConversion.convert(_data, _schema, prefers_large_types) # TODO: Beside the validation on number of columns, we should also check # whether the Arrow Schema is compatible with the user provided Schema. diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 172a4fc4b2343..18360fd813921 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -81,7 +81,7 @@ def toPandas(self) -> "PandasDataFrameLike": from pyspark.sql.pandas.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() - to_arrow_schema(self.schema) + to_arrow_schema(self.schema, prefers_large_types=jconf.arrowUseLargeVarTypes()) except Exception as e: if jconf.arrowPySparkFallbackEnabled(): msg = ( @@ -236,7 +236,12 @@ def toArrow(self) -> "pa.Table": from pyspark.sql.pandas.utils import require_minimum_pyarrow_version require_minimum_pyarrow_version() - schema = to_arrow_schema(self.schema, error_on_duplicated_field_names_in_struct=True) + prefers_large_var_types = jconf.arrowUseLargeVarTypes() + schema = to_arrow_schema( + self.schema, + error_on_duplicated_field_names_in_struct=True, + prefers_large_types=prefers_large_var_types, + ) import pyarrow as pa @@ -322,7 +327,8 @@ def _collect_as_arrow( from pyspark.sql.pandas.types import to_arrow_schema import pyarrow as pa - schema = to_arrow_schema(self.schema) + prefers_large_var_types = self.sparkSession._jconf.arrowUseLargeVarTypes() + schema = to_arrow_schema(self.schema, prefers_large_types=prefers_large_var_types) empty_arrays = [pa.array([], type=field.type) for field in schema] return [pa.RecordBatch.from_arrays(empty_arrays, schema=schema)] @@ -715,9 +721,16 @@ def _create_from_pandas_with_arrow( pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step)) # Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream + prefers_large_var_types = self._jconf.arrowUseLargeVarTypes() arrow_data = [ [ - (c, to_arrow_type(t) if t is not None else None, t) + ( + c, + to_arrow_type(t, prefers_large_types=prefers_large_var_types) + if t is not None + else None, + t, + ) for (_, c), t in zip(pdf_slice.items(), spark_types) ] for pdf_slice in pdf_slices @@ -785,8 +798,13 @@ def _create_from_arrow_table( if not isinstance(schema, StructType): schema = from_arrow_schema(table.schema, prefer_timestamp_ntz=prefer_timestamp_ntz) + prefers_large_var_types = self._jconf.arrowUseLargeVarTypes() table = _check_arrow_table_timestamps_localize(table, schema, True, timezone).cast( - to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True) + to_arrow_schema( + schema, + error_on_duplicated_field_names_in_struct=True, + prefers_large_types=prefers_large_var_types, + ) ) # Chunk the Arrow Table into RecordBatches diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 536bf7307065c..cd2e1230418f3 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -793,6 +793,7 @@ def __init__( assign_cols_by_name, state_object_schema, arrow_max_records_per_batch, + prefers_large_var_types, ): super(ApplyInPandasWithStateSerializer, self).__init__( timezone, safecheck, assign_cols_by_name @@ -808,7 +809,9 @@ def __init__( ] ) - self.result_count_pdf_arrow_type = to_arrow_type(self.result_count_df_type) + self.result_count_pdf_arrow_type = to_arrow_type( + 
self.result_count_df_type, prefers_large_types=prefers_large_var_types + ) self.result_state_df_type = StructType( [ @@ -819,7 +822,9 @@ def __init__( ] ) - self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) + self.result_state_pdf_arrow_type = to_arrow_type( + self.result_state_df_type, prefers_large_types=prefers_large_var_types + ) self.arrow_max_records_per_batch = arrow_max_records_per_batch def load_stream(self, stream): diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index d65126bb3db9e..fcd70d4d18399 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -67,6 +67,7 @@ def to_arrow_type( dt: DataType, error_on_duplicated_field_names_in_struct: bool = False, timestamp_utc: bool = True, + prefers_large_types: bool = False, ) -> "pa.DataType": """ Convert Spark data type to PyArrow type @@ -107,8 +108,12 @@ def to_arrow_type( arrow_type = pa.float64() elif type(dt) == DecimalType: arrow_type = pa.decimal128(dt.precision, dt.scale) + elif type(dt) == StringType and prefers_large_types: + arrow_type = pa.large_string() elif type(dt) == StringType: arrow_type = pa.string() + elif type(dt) == BinaryType and prefers_large_types: + arrow_type = pa.large_binary() elif type(dt) == BinaryType: arrow_type = pa.binary() elif type(dt) == DateType: @@ -125,19 +130,34 @@ def to_arrow_type( elif type(dt) == ArrayType: field = pa.field( "element", - to_arrow_type(dt.elementType, error_on_duplicated_field_names_in_struct, timestamp_utc), + to_arrow_type( + dt.elementType, + error_on_duplicated_field_names_in_struct, + timestamp_utc, + prefers_large_types, + ), nullable=dt.containsNull, ) arrow_type = pa.list_(field) elif type(dt) == MapType: key_field = pa.field( "key", - to_arrow_type(dt.keyType, error_on_duplicated_field_names_in_struct, timestamp_utc), + to_arrow_type( + dt.keyType, + error_on_duplicated_field_names_in_struct, + timestamp_utc, + prefers_large_types, + ), nullable=False, ) value_field = pa.field( "value", - to_arrow_type(dt.valueType, error_on_duplicated_field_names_in_struct, timestamp_utc), + to_arrow_type( + dt.valueType, + error_on_duplicated_field_names_in_struct, + timestamp_utc, + prefers_large_types, + ), nullable=dt.valueContainsNull, ) arrow_type = pa.map_(key_field, value_field) @@ -152,7 +172,10 @@ def to_arrow_type( pa.field( field.name, to_arrow_type( - field.dataType, error_on_duplicated_field_names_in_struct, timestamp_utc + field.dataType, + error_on_duplicated_field_names_in_struct, + timestamp_utc, + prefers_large_types, ), nullable=field.nullable, ) @@ -163,7 +186,10 @@ def to_arrow_type( arrow_type = pa.null() elif isinstance(dt, UserDefinedType): arrow_type = to_arrow_type( - dt.sqlType(), error_on_duplicated_field_names_in_struct, timestamp_utc + dt.sqlType(), + error_on_duplicated_field_names_in_struct, + timestamp_utc, + prefers_large_types, ) elif type(dt) == VariantType: fields = [ @@ -185,6 +211,7 @@ def to_arrow_schema( schema: StructType, error_on_duplicated_field_names_in_struct: bool = False, timestamp_utc: bool = True, + prefers_large_types: bool = False, ) -> "pa.Schema": """ Convert a schema from Spark to Arrow @@ -212,7 +239,12 @@ def to_arrow_schema( fields = [ pa.field( field.name, - to_arrow_type(field.dataType, error_on_duplicated_field_names_in_struct, timestamp_utc), + to_arrow_type( + field.dataType, + error_on_duplicated_field_names_in_struct, + timestamp_utc, + prefers_large_types, + ), nullable=field.nullable, ) for field in schema diff 
--git a/python/pyspark/sql/tests/arrow/test_arrow.py b/python/pyspark/sql/tests/arrow/test_arrow.py index a2ee113b6386e..065f97fcf7c78 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow.py +++ b/python/pyspark/sql/tests/arrow/test_arrow.py @@ -730,7 +730,11 @@ def test_createDataFrame_arrow_truncate_timestamp(self): def test_schema_conversion_roundtrip(self): from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema - arrow_schema = to_arrow_schema(self.schema) + arrow_schema = to_arrow_schema(self.schema, prefers_large_types=False) + schema_rt = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True) + self.assertEqual(self.schema, schema_rt) + + arrow_schema = to_arrow_schema(self.schema, prefers_large_types=True) schema_rt = from_arrow_schema(arrow_schema, prefer_timestamp_ntz=True) self.assertEqual(self.schema, schema_rt) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e799498cdd80b..7bac0157caee1 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -117,10 +117,12 @@ def wrap_udf(f, args_offsets, kwargs_offsets, return_type): return args_kwargs_offsets, lambda *a: func(*a) -def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type): +def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) - arrow_return_type = to_arrow_type(return_type) + arrow_return_type = to_arrow_type( + return_type, prefers_large_types=use_large_var_types(runner_conf) + ) def verify_result_type(result): if not hasattr(result, "__len__"): @@ -159,7 +161,9 @@ def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, return_type, runner_co func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) - arrow_return_type = to_arrow_type(return_type) + arrow_return_type = to_arrow_type( + return_type, prefers_large_types=use_large_var_types(runner_conf) + ) # "result_func" ensures the result of a Python UDF to be consistent with/without Arrow # optimization. 
@@ -205,8 +209,10 @@ def verify_result_length(result, length): ) -def wrap_pandas_batch_iter_udf(f, return_type): - arrow_return_type = to_arrow_type(return_type) +def wrap_pandas_batch_iter_udf(f, return_type, runner_conf): + arrow_return_type = to_arrow_type( + return_type, prefers_large_types=use_large_var_types(runner_conf) + ) iter_type_label = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" def verify_result(result): @@ -303,8 +309,10 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu ) -def wrap_arrow_batch_iter_udf(f, return_type): - arrow_return_type = to_arrow_type(return_type) +def wrap_arrow_batch_iter_udf(f, return_type, runner_conf): + arrow_return_type = to_arrow_type( + return_type, prefers_large_types=use_large_var_types(runner_conf) + ) def verify_result(result): if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): @@ -364,6 +372,7 @@ def wrapped(left_key_table, left_value_table, right_key_table, right_value_table def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf): + _use_large_var_types = use_large_var_types(runner_conf) _assign_cols_by_name = assign_cols_by_name(runner_conf) def wrapped(left_key_series, left_value_series, right_key_series, right_value_series): @@ -384,7 +393,8 @@ def wrapped(left_key_series, left_value_series, right_key_series, right_value_se return result - return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), to_arrow_type(return_type))] + arrow_return_type = to_arrow_type(return_type, _use_large_var_types) + return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), arrow_return_type)] def verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types): @@ -482,10 +492,12 @@ def wrapped(key_table, value_table): return result.to_batches() - return lambda k, v: (wrapped(k, v), to_arrow_type(return_type)) + arrow_return_type = to_arrow_type(return_type, use_large_var_types(runner_conf)) + return lambda k, v: (wrapped(k, v), arrow_return_type) def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): + _use_large_var_types = use_large_var_types(runner_conf) _assign_cols_by_name = assign_cols_by_name(runner_conf) def wrapped(key_series, value_series): @@ -502,7 +514,8 @@ def wrapped(key_series, value_series): return result - return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] + arrow_return_type = to_arrow_type(return_type, _use_large_var_types) + return lambda k, v: [(wrapped(k, v), arrow_return_type)] def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf): @@ -517,7 +530,8 @@ def wrapped(stateful_processor_api_client, mode, key, value_series_gen): return result_iter - return lambda p, m, k, v: [(wrapped(p, m, k, v), to_arrow_type(return_type))] + arrow_return_type = to_arrow_type(return_type, use_large_var_types(runner_conf)) + return lambda p, m, k, v: [(wrapped(p, m, k, v), arrow_return_type)] def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runner_conf): @@ -535,10 +549,11 @@ def wrapped(stateful_processor_api_client, mode, key, value_series_gen): return result_iter - return lambda p, m, k, v: [(wrapped(p, m, k, v), to_arrow_type(return_type))] + arrow_return_type = to_arrow_type(return_type, use_large_var_types(runner_conf)) + return lambda p, m, k, v: [(wrapped(p, m, k, v), arrow_return_type)] -def wrap_grouped_map_pandas_udf_with_state(f, return_type): +def wrap_grouped_map_pandas_udf_with_state(f, return_type, runner_conf): """ Provides a new lambda instance 
wrapping user function of applyInPandasWithState. @@ -553,6 +568,7 @@ def wrap_grouped_map_pandas_udf_with_state(f, return_type): Along with the returned iterator, the lambda instance will also produce the return_type as converted to the arrow schema. """ + _use_large_var_types = use_large_var_types(runner_conf) def wrapped(key_series, value_series_gen, state): """ @@ -627,13 +643,16 @@ def verify_element(result): state, ) - return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))] + arrow_return_type = to_arrow_type(return_type, _use_large_var_types) + return lambda k, v, s: [(wrapped(k, v, s), arrow_return_type)] -def wrap_grouped_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type): +def wrap_grouped_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) - arrow_return_type = to_arrow_type(return_type) + arrow_return_type = to_arrow_type( + return_type, prefers_large_types=use_large_var_types(runner_conf) + ) def wrapped(*series): import pandas as pd @@ -653,9 +672,13 @@ def wrap_window_agg_pandas_udf( window_bound_types_str = runner_conf.get("pandas_window_bound_types") window_bound_type = [t.strip().lower() for t in window_bound_types_str.split(",")][udf_index] if window_bound_type == "bounded": - return wrap_bounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type) + return wrap_bounded_window_agg_pandas_udf( + f, args_offsets, kwargs_offsets, return_type, runner_conf + ) elif window_bound_type == "unbounded": - return wrap_unbounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type) + return wrap_unbounded_window_agg_pandas_udf( + f, args_offsets, kwargs_offsets, return_type, runner_conf + ) else: raise PySparkRuntimeError( errorClass="INVALID_WINDOW_BOUND_TYPE", @@ -665,14 +688,16 @@ def wrap_window_agg_pandas_udf( ) -def wrap_unbounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type): +def wrap_unbounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) # This is similar to grouped_agg_pandas_udf, the only difference # is that window_agg_pandas_udf needs to repeat the return value # to match window length, where grouped_agg_pandas_udf just returns # the scalar value. - arrow_return_type = to_arrow_type(return_type) + arrow_return_type = to_arrow_type( + return_type, prefers_large_types=use_large_var_types(runner_conf) + ) def wrapped(*series): import pandas as pd @@ -686,12 +711,14 @@ def wrapped(*series): ) -def wrap_bounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type): +def wrap_bounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): # args_offsets should have at least 2 for begin_index, end_index. 
assert len(args_offsets) >= 2, len(args_offsets) func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets[2:], kwargs_offsets) - arrow_return_type = to_arrow_type(return_type) + arrow_return_type = to_arrow_type( + return_type, prefers_large_types=use_large_var_types(runner_conf) + ) def wrapped(begin_index, end_index, *series): import pandas as pd @@ -865,15 +892,15 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: - return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets, return_type) + return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf) elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF: return wrap_arrow_batch_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf) elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF: - return args_offsets, wrap_pandas_batch_iter_udf(func, return_type) + return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, runner_conf) elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF: - return args_offsets, wrap_pandas_batch_iter_udf(func, return_type) + return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, runner_conf) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: - return args_offsets, wrap_arrow_batch_iter_udf(func, return_type) + return args_offsets, wrap_arrow_batch_iter_udf(func, return_type, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it return args_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) @@ -881,7 +908,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it return args_offsets, wrap_grouped_map_arrow_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: - return args_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type) + return args_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type, runner_conf) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: return args_offsets, wrap_grouped_transform_with_state_pandas_udf( func, return_type, runner_conf @@ -897,7 +924,9 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it return args_offsets, wrap_cogrouped_map_arrow_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: - return wrap_grouped_agg_pandas_udf(func, args_offsets, kwargs_offsets, return_type) + return wrap_grouped_agg_pandas_udf( + func, args_offsets, kwargs_offsets, return_type, runner_conf + ) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: return wrap_window_agg_pandas_udf( func, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index @@ -922,6 +951,10 @@ def assign_cols_by_name(runner_conf): ) +def use_large_var_types(runner_conf): + return runner_conf.get("spark.sql.execution.arrow.useLargeVarTypes", "false").lower() == "true" + + # Read and process a serialized user-defined table function (UDTF) from a socket. # It expects the UDTF to be in a specific format and performs various checks to # ensure the UDTF is valid. 
This function also prepares a mapper function for applying @@ -1254,7 +1287,9 @@ def check(row): def wrap_arrow_udtf(f, return_type): import pandas as pd - arrow_return_type = to_arrow_type(return_type) + arrow_return_type = to_arrow_type( + return_type, prefers_large_types=use_large_var_types(runner_conf) + ) return_type_size = len(return_type) def verify_result(result): @@ -1499,6 +1534,7 @@ def read_udfs(pickleSer, infile, eval_type): # NOTE: if timezone is set here, that implies respectSessionTimeZone is True timezone = runner_conf.get("spark.sql.session.timeZone", None) + prefers_large_var_types = use_large_var_types(runner_conf) safecheck = ( runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower() == "true" @@ -1521,6 +1557,7 @@ def read_udfs(pickleSer, infile, eval_type): _assign_cols_by_name, state_object_schema, arrow_max_records_per_batch, + prefers_large_var_types, ) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: arrow_max_records_per_batch = runner_conf.get( diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala index 76cd436b39b58..cb517c689ea16 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala @@ -56,6 +56,8 @@ private[sql] object SqlApiConf { val LEGACY_TIME_PARSER_POLICY_KEY: String = SqlApiConfHelper.LEGACY_TIME_PARSER_POLICY_KEY val CASE_SENSITIVE_KEY: String = SqlApiConfHelper.CASE_SENSITIVE_KEY val SESSION_LOCAL_TIMEZONE_KEY: String = SqlApiConfHelper.SESSION_LOCAL_TIMEZONE_KEY + val ARROW_EXECUTION_USE_LARGE_VAR_TYPES: String = + SqlApiConfHelper.ARROW_EXECUTION_USE_LARGE_VAR_TYPES val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = { SqlApiConfHelper.LOCAL_RELATION_CACHE_THRESHOLD_KEY } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala index 13ef13e5894e0..486a7dfb58dd0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala @@ -33,6 +33,7 @@ private[sql] object SqlApiConfHelper { val SESSION_LOCAL_TIMEZONE_KEY: String = "spark.sql.session.timeZone" val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = "spark.sql.session.localRelationCacheThreshold" val DEFAULT_COLLATION: String = "spark.sql.session.collation.default" + val ARROW_EXECUTION_USE_LARGE_VAR_TYPES = "spark.sql.execution.arrow.useLargeVarTypes" val confGetter: AtomicReference[() => SqlApiConf] = { new AtomicReference[() => SqlApiConf](() => DefaultSqlApiConf) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 55d1aff8261d4..587ca43e57300 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -202,7 +202,7 @@ private[sql] object ArrowUtils { schema: StructType, timeZoneId: String, errorOnDuplicatedFieldNames: Boolean, - largeVarTypes: Boolean = false): Schema = { + largeVarTypes: Boolean): Schema = { new Schema(schema.map { field => toArrowField( field.name, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 065b4b8c821a6..c496b0e82c263 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -33,8 +33,10 @@ object ArrowWriter { def create( schema: StructType, timeZoneId: String, - errorOnDuplicatedFieldNames: Boolean = true): ArrowWriter = { - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames) + errorOnDuplicatedFieldNames: Boolean = true, + largeVarTypes: Boolean = false): ArrowWriter = { + val arrowSchema = ArrowUtils.toArrowSchema( + schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) create(root) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2211165ad6e50..d37e33b5adcdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3436,9 +3436,8 @@ object SQLConf { .doc("When using Apache Arrow, use large variable width vectors for string and binary " + "types. Regular string and binary types have a 2GiB limit for a column in a single " + "record batch. Large variable types remove this limitation at the cost of higher memory " + - "usage per value. Note that this only works for DataFrame.mapInArrow.") + "usage per value.") .version("3.5.0") - .internal() .booleanConf .createWithDefault(false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala index c705a6b791bd1..7124c94b390d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala @@ -30,7 +30,8 @@ class ArrowUtilsSuite extends SparkFunSuite { def roundtrip(dt: DataType): Unit = { dt match { case schema: StructType => - assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null, true)) === schema) + assert(ArrowUtils.fromArrowSchema( + ArrowUtils.toArrowSchema(schema, null, true, false)) === schema) case _ => roundtrip(new StructType().add("value", dt)) } @@ -69,7 +70,7 @@ class ArrowUtilsSuite extends SparkFunSuite { def roundtripWithTz(timeZoneId: String): Unit = { val schema = new StructType().add("value", TimestampType) - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, true) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, true, false) val fieldType = arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp] assert(fieldType.getTimezone() === timeZoneId) assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema) @@ -105,9 +106,9 @@ class ArrowUtilsSuite extends SparkFunSuite { def check(dt: DataType, expected: DataType): Unit = { val schema = new StructType().add("value", dt) intercept[SparkUnsupportedOperationException] { - ArrowUtils.toArrowSchema(schema, null, true) + ArrowUtils.toArrowSchema(schema, null, true, false) } - assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null, false)) + assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null, false, false)) === new StructType().add("value", expected)) } diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala index 2791c6b6add55..c7b4748f12221 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SQLImplicitsTestSuite.scala @@ -64,7 +64,8 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { input = Iterator.single(expected), enc = encoder, allocator = allocator, - timeZoneId = "UTC") + timeZoneId = "UTC", + largeVarTypes = false) val fromArrow = ArrowDeserializers.deserializeFromArrow( input = Iterator.single(batch.toByteArray), encoder = encoder, diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index f6662b3351ba7..58e19389cae2e 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -106,7 +106,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { maxRecordsPerBatch = maxRecordsPerBatch, maxBatchSize = maxBatchSize, batchSizeCheckInterval = batchSizeCheckInterval, - timeZoneId = "UTC") + timeZoneId = "UTC", + largeVarTypes = false) val inspectedIterator = if (inspectBatch != null) { arrowIterator.map { batch => @@ -183,7 +184,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { allocator, maxRecordsPerBatch = 1024, maxBatchSize = 8 * 1024, - timeZoneId = "UTC") + timeZoneId = "UTC", + largeVarTypes = false) } private def compareIterators[T](expected: Iterator[T], actual: Iterator[T]): Unit = { @@ -626,7 +628,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { allocator, maxRecordsPerBatch = 128, maxBatchSize = 1024, - timeZoneId = "UTC") + timeZoneId = "UTC", + largeVarTypes = false) intercept[NullPointerException] { iterator.next() } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala index b0d7d6aa5d134..e4ac4a1ba6199 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect import java.net.URI import java.nio.file.{Files, Paths} +import java.util.Locale import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong @@ -110,7 +111,8 @@ class SparkSession private[sql] ( private def createDataset[T](encoder: AgnosticEncoder[T], data: Iterator[T]): Dataset[T] = { newDataset(encoder) { builder => if (data.nonEmpty) { - val arrowData = ArrowSerializer.serialize(data, encoder, allocator, timeZoneId) + val arrowData = + ArrowSerializer.serialize(data, encoder, allocator, timeZoneId, largeVarTypes) if (arrowData.size() <= conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt) { builder.getLocalRelationBuilder .setSchema(encoder.schema.json) @@ -467,6 +469,8 @@ class SparkSession private[sql] ( } private[sql] def timeZoneId: String = conf.get(SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY) + private[sql] def 
largeVarTypes: Boolean = + conf.get(SqlApiConf.ARROW_EXECUTION_USE_LARGE_VAR_TYPES).toLowerCase(Locale.ROOT).toBoolean private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = { val value = executeInternal(plan) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala index c01390bf07857..584a318f039d8 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._ import com.google.protobuf.ByteString import org.apache.arrow.memory.BufferAllocator -import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, IntervalYearVector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector._ import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel} import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer} @@ -50,8 +50,10 @@ import org.apache.spark.unsafe.types.VariantVal class ArrowSerializer[T]( private[this] val enc: AgnosticEncoder[T], private[this] val allocator: BufferAllocator, - private[this] val timeZoneId: String) { - private val (root, serializer) = ArrowSerializer.serializerFor(enc, allocator, timeZoneId) + private[this] val timeZoneId: String, + private[this] val largeVarTypes: Boolean) { + private val (root, serializer) = + ArrowSerializer.serializerFor(enc, allocator, timeZoneId, largeVarTypes) private val vectors = root.getFieldVectors.asScala private val unloader = new VectorUnloader(root) private val schemaBytes = { @@ -144,12 +146,13 @@ object ArrowSerializer { maxRecordsPerBatch: Int, maxBatchSize: Long, timeZoneId: String, + largeVarTypes: Boolean, batchSizeCheckInterval: Int = 128): CloseableIterator[Array[Byte]] = { assert(maxRecordsPerBatch > 0) assert(maxBatchSize > 0) assert(batchSizeCheckInterval > 0) new CloseableIterator[Array[Byte]] { - private val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId) + private val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId, largeVarTypes) private val bytes = new ByteArrayOutputStream private var hasWrittenFirstBatch = false @@ -191,8 +194,9 @@ object ArrowSerializer { input: Iterator[T], enc: AgnosticEncoder[T], allocator: BufferAllocator, - timeZoneId: String): ByteString = { - val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId) + timeZoneId: String, + largeVarTypes: Boolean): ByteString = { + val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId, largeVarTypes) try { input.foreach(serializer.append) val output = ByteString.newOutput() @@ -211,9 +215,14 @@ object ArrowSerializer { def serializerFor[T]( encoder: AgnosticEncoder[T], allocator: BufferAllocator, - timeZoneId: String): (VectorSchemaRoot, Serializer) = { + timeZoneId: String, + largeVarTypes: Boolean): (VectorSchemaRoot, Serializer) = { val arrowSchema = - ArrowUtils.toArrowSchema(encoder.schema, timeZoneId, errorOnDuplicatedFieldNames = true) + ArrowUtils.toArrowSchema( + encoder.schema, + 
timeZoneId, + errorOnDuplicatedFieldNames = true, + largeVarTypes = largeVarTypes) val root = VectorSchemaRoot.create(arrowSchema, allocator) val serializer = if (encoder.schema != encoder.dataType) { assert(root.getSchema.getFields.size() == 1) @@ -264,19 +273,36 @@ object ArrowSerializer { new FieldSerializer[String, VarCharVector](v) { override def set(index: Int, value: String): Unit = setString(v, index, value) } + case (StringEncoder, v: LargeVarCharVector) => + new FieldSerializer[String, LargeVarCharVector](v) { + override def set(index: Int, value: String): Unit = setString(v, index, value) + } case (JavaEnumEncoder(_), v: VarCharVector) => new FieldSerializer[Enum[_], VarCharVector](v) { override def set(index: Int, value: Enum[_]): Unit = setString(v, index, value.name()) } + case (JavaEnumEncoder(_), v: LargeVarCharVector) => + new FieldSerializer[Enum[_], LargeVarCharVector](v) { + override def set(index: Int, value: Enum[_]): Unit = setString(v, index, value.name()) + } case (ScalaEnumEncoder(_, _), v: VarCharVector) => new FieldSerializer[Enumeration#Value, VarCharVector](v) { override def set(index: Int, value: Enumeration#Value): Unit = setString(v, index, value.toString) } + case (ScalaEnumEncoder(_, _), v: LargeVarCharVector) => + new FieldSerializer[Enumeration#Value, LargeVarCharVector](v) { + override def set(index: Int, value: Enumeration#Value): Unit = + setString(v, index, value.toString) + } case (BinaryEncoder, v: VarBinaryVector) => new FieldSerializer[Array[Byte], VarBinaryVector](v) { override def set(index: Int, value: Array[Byte]): Unit = vector.setSafe(index, value) } + case (BinaryEncoder, v: LargeVarBinaryVector) => + new FieldSerializer[Array[Byte], LargeVarBinaryVector](v) { + override def set(index: Int, value: Array[Byte]): Unit = vector.setSafe(index, value) + } case (SparkDecimalEncoder(_), v: DecimalVector) => new FieldSerializer[Decimal, DecimalVector](v) { override def set(index: Int, value: Decimal): Unit = @@ -477,7 +503,7 @@ object ArrowSerializer { private val methodLookup = MethodHandles.lookup() - private def setString(vector: VarCharVector, index: Int, string: String): Unit = { + private def setString(vector: VariableWidthFieldVector, index: Int, string: String): Unit = { val bytes = Text.encode(string) vector.setSafe(index, bytes, 0, bytes.limit()) } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala index 53d8d46e62689..3dbfce18e7b48 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala @@ -20,7 +20,7 @@ import java.math.{BigDecimal => JBigDecimal} import java.sql.{Date, Timestamp} import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneOffset} -import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, IntervalYearVector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector} +import org.apache.arrow.vector._ import org.apache.arrow.vector.util.Text import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkIntervalUtils, SparkStringUtils, TimestampFormatter} @@ -82,7 +82,9 @@ object ArrowVectorReader { case v: Float8Vector => new 
Float8VectorReader(v) case v: DecimalVector => new DecimalVectorReader(v) case v: VarCharVector => new VarCharVectorReader(v) + case v: LargeVarCharVector => new LargeVarCharVectorReader(v) case v: VarBinaryVector => new VarBinaryVectorReader(v) + case v: LargeVarBinaryVector => new LargeVarBinaryVectorReader(v) case v: DurationVector => new DurationVectorReader(v) case v: IntervalYearVector => new IntervalYearVectorReader(v) case v: DateDayVector => new DateDayVectorReader(v, timeZoneId) @@ -189,12 +191,29 @@ private[arrow] class VarCharVectorReader(v: VarCharVector) override def getString(i: Int): String = Text.decode(vector.get(i)) } +private[arrow] class LargeVarCharVectorReader(v: LargeVarCharVector) + extends TypedArrowVectorReader[LargeVarCharVector](v) { + // This is currently a bit heavy on allocations: + // - byte array created in VarCharVector.get + // - CharBuffer created CharSetEncoder + // - char array in String + // By using direct buffers and reusing the char buffer + // we could get rid of the first two allocations. + override def getString(i: Int): String = Text.decode(vector.get(i)) +} + private[arrow] class VarBinaryVectorReader(v: VarBinaryVector) extends TypedArrowVectorReader[VarBinaryVector](v) { override def getBytes(i: Int): Array[Byte] = vector.get(i) override def getString(i: Int): String = SparkStringUtils.getHexString(getBytes(i)) } +private[arrow] class LargeVarBinaryVectorReader(v: LargeVarBinaryVector) + extends TypedArrowVectorReader[LargeVarBinaryVector](v) { + override def getBytes(i: Int): Array[Byte] = vector.get(i) + override def getString(i: Int): String = SparkStringUtils.getHexString(getBytes(i)) +} + private[arrow] class DurationVectorReader(v: DurationVector) extends TypedArrowVectorReader[DurationVector](v) { override def getDuration(i: Int): Duration = vector.getObject(i) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index 497576a6630d3..fc3f180634167 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -90,14 +90,16 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) maxRecordsPerBatch: Int, maxBatchSize: Long, timeZoneId: String, - errorOnDuplicatedFieldNames: Boolean): Iterator[InternalRow] => Iterator[Batch] = { rows => + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean): Iterator[InternalRow] => Iterator[Batch] = { rows => val batches = ArrowConverters.toBatchWithSchemaIterator( rows, schema, maxRecordsPerBatch, maxBatchSize, timeZoneId, - errorOnDuplicatedFieldNames) + errorOnDuplicatedFieldNames, + largeVarTypes) batches.map(b => b -> batches.rowCountInLastBatch) } @@ -110,6 +112,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) val schema = dataframe.schema val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone + val largeVarTypes = spark.sessionState.conf.arrowUseLargeVarTypes // Conservatively sets it 70% because the size is not accurate but estimated. 
val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong @@ -118,7 +121,8 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) maxRecordsPerBatch, maxBatchSize, timeZoneId, - errorOnDuplicatedFieldNames = false) + errorOnDuplicatedFieldNames = false, + largeVarTypes = largeVarTypes) var numSent = 0 def sendBatch(bytes: Array[Byte], count: Long, startOffset: Long): Unit = { @@ -239,7 +243,8 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) ArrowConverters.createEmptyArrowBatch( schema, timeZoneId, - errorOnDuplicatedFieldNames = false), + errorOnDuplicatedFieldNames = false, + largeVarTypes = largeVarTypes), 0L, 0L) } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 6019969be05b9..2a296522b620e 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2550,7 +2550,8 @@ class SparkConnectPlanner( result.iterator, StringEncoder, ArrowUtils.rootAllocator, - session.sessionState.conf.sessionLocalTimeZone) + session.sessionState.conf.sessionLocalTimeZone, + session.sessionState.conf.arrowUseLargeVarTypes) val sqlCommandResult = SqlCommandResult.newBuilder() sqlCommandResult.getRelationBuilder.getLocalRelationBuilder.setData(arrowData) responseObserver.onNext( @@ -2613,13 +2614,15 @@ class SparkConnectPlanner( val schema = df.schema val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong val timeZoneId = session.sessionState.conf.sessionLocalTimeZone + val largeVarTypes = session.sessionState.conf.arrowUseLargeVarTypes // Convert the data. 
val bytes = if (rows.isEmpty) { ArrowConverters.createEmptyArrowBatch( schema, timeZoneId, - errorOnDuplicatedFieldNames = false) + errorOnDuplicatedFieldNames = false, + largeVarTypes = largeVarTypes) } else { val batches = ArrowConverters.toBatchWithSchemaIterator( rowIter = rows.iterator, @@ -2627,7 +2630,8 @@ class SparkConnectPlanner( maxRecordsPerBatch = -1, maxEstimatedBatchSize = maxBatchSize, timeZoneId = timeZoneId, - errorOnDuplicatedFieldNames = false) + errorOnDuplicatedFieldNames = false, + largeVarTypes = largeVarTypes) assert(batches.hasNext) val bytes = batches.next() assert(!batches.hasNext, s"remaining batches: ${batches.size}") diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala index db430549818de..76c88d515ec0b 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala @@ -146,7 +146,12 @@ trait SparkConnectServerTest extends SharedSparkSession { protected def buildLocalRelation[A <: Product: TypeTag](data: Seq[A]) = { val encoder = ScalaReflection.encoderFor[A] val arrowData = - ArrowSerializer.serialize(data.iterator, encoder, allocator, TimeZone.getDefault.getID) + ArrowSerializer.serialize( + data.iterator, + encoder, + allocator, + TimeZone.getDefault.getID, + largeVarTypes = false) val localRelation = proto.LocalRelation .newBuilder() .setData(arrowData) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 2a09d5f8e8bd5..72f7065b44240 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -101,7 +101,8 @@ trait SparkConnectPlanTest extends SharedSparkSession { Long.MaxValue, Long.MaxValue, timeZoneId, - true) + true, + false) .next() localRelationBuilder.setData(ByteString.copyFrom(bytes)) @@ -478,7 +479,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { test("Empty ArrowBatch") { val schema = StructType(Seq(StructField("int", IntegerType))) - val data = ArrowConverters.createEmptyArrowBatch(schema, null, true) + val data = ArrowConverters.createEmptyArrowBatch(schema, null, true, false) val localRelation = proto.Relation .newBuilder() .setLocalRelation( diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 2bbd6863b1105..494aceb2fb587 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -1115,7 +1115,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { Long.MaxValue, Long.MaxValue, null, - true) + true, + false) .next() proto.Relation .newBuilder() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index ad58fc0c2fcf3..1efd8f9e32208 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -249,7 +249,13 @@ private[sql] object SQLUtils extends Logging { val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, true, context) + ArrowConverters.fromBatchIterator( + iter, + schema, + timeZoneId, + true, + false, + context) } sparkSession.internalCreateDataFrame(rdd.setName("arrow"), schema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala index d78a3a391edb6..8930b5895d320 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala @@ -2089,7 +2089,8 @@ class Dataset[T] private[sql]( val buffer = new ByteArrayOutputStream() val out = new DataOutputStream(outputStream) val batchWriter = - new ArrowBatchStreamWriter(schema, buffer, timeZoneId, errorOnDuplicatedFieldNames = true) + new ArrowBatchStreamWriter( + schema, buffer, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = false) val arrowBatchRdd = toArrowBatchRdd(plan) val numPartitions = arrowBatchRdd.partitions.length @@ -2140,12 +2141,14 @@ class Dataset[T] private[sql]( val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone val errorOnDuplicatedFieldNames = sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy" + val largeVarTypes = sparkSession.sessionState.conf.arrowUseLargeVarTypes PythonRDD.serveToStream("serve-Arrow") { outputStream => withAction("collectAsArrowToPython", queryExecution) { plan => val out = new DataOutputStream(outputStream) val batchWriter = - new ArrowBatchStreamWriter(schema, out, timeZoneId, errorOnDuplicatedFieldNames) + new ArrowBatchStreamWriter( + schema, out, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) // Batches ordered by (index of partition, batch index in that partition) tuple val batchOrder = ArrayBuffer.empty[(Int, Int)] @@ -2294,10 +2297,18 @@ class Dataset[T] private[sql]( val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone val errorOnDuplicatedFieldNames = sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy" + val largeVarTypes = + sparkSession.sessionState.conf.arrowUseLargeVarTypes plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() ArrowConverters.toBatchIterator( - iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, context) + iter, + schemaCaptured, + maxRecordsPerBatch, + timeZoneId, + errorOnDuplicatedFieldNames, + largeVarTypes, + context) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 7ea1bd6ff7dc6..ed490347ae821 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -51,9 +51,11 @@ private[sql] class ArrowBatchStreamWriter( schema: StructType, out: OutputStream, timeZoneId: String, - errorOnDuplicatedFieldNames: Boolean) { + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean) { - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, 
errorOnDuplicatedFieldNames) + val arrowSchema = ArrowUtils.toArrowSchema( + schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) val writeChannel = new WriteChannel(Channels.newChannel(out)) // Write the Arrow schema first, before batches @@ -81,10 +83,11 @@ private[sql] object ArrowConverters extends Logging { maxRecordsPerBatch: Long, timeZoneId: String, errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean, context: TaskContext) extends Iterator[Array[Byte]] with AutoCloseable { protected val arrowSchema = - ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames) + ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) private val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) @@ -137,9 +140,16 @@ private[sql] object ArrowConverters extends Logging { maxEstimatedBatchSize: Long, timeZoneId: String, errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean, context: TaskContext) extends ArrowBatchIterator( - rowIter, schema, maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, context) { + rowIter, + schema, + maxRecordsPerBatch, + timeZoneId, + errorOnDuplicatedFieldNames, + largeVarTypes, + context) { private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema) var rowCountInLastBatch: Long = 0 @@ -205,9 +215,16 @@ private[sql] object ArrowConverters extends Logging { maxRecordsPerBatch: Long, timeZoneId: String, errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean, context: TaskContext): ArrowBatchIterator = { new ArrowBatchIterator( - rowIter, schema, maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, context) + rowIter, + schema, + maxRecordsPerBatch, + timeZoneId, + errorOnDuplicatedFieldNames, + largeVarTypes, + context) } /** @@ -220,19 +237,21 @@ private[sql] object ArrowConverters extends Logging { maxRecordsPerBatch: Long, maxEstimatedBatchSize: Long, timeZoneId: String, - errorOnDuplicatedFieldNames: Boolean): ArrowBatchWithSchemaIterator = { + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean): ArrowBatchWithSchemaIterator = { new ArrowBatchWithSchemaIterator( rowIter, schema, maxRecordsPerBatch, maxEstimatedBatchSize, - timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get()) + timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes, TaskContext.get()) } private[sql] def createEmptyArrowBatch( schema: StructType, timeZoneId: String, - errorOnDuplicatedFieldNames: Boolean): Array[Byte] = { + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean): Array[Byte] = { val batches = new ArrowBatchWithSchemaIterator( Iterator.empty, schema, 0L, 0L, - timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get()) { + timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes, TaskContext.get()) { override def hasNext: Boolean = true } Utils.tryWithSafeFinally { @@ -299,12 +318,13 @@ private[sql] object ArrowConverters extends Logging { schema: StructType, timeZoneId: String, errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean, context: TaskContext) extends InternalRowIterator(arrowBatchIter, context) { override def nextBatch(): (Iterator[InternalRow], StructType) = { val arrowSchema = - ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames) + ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) val root = VectorSchemaRoot.create(arrowSchema, allocator) resources.append(root) val arrowRecordBatch = 
@@ -344,9 +364,12 @@ private[sql] object ArrowConverters extends Logging {
       schema: StructType,
       timeZoneId: String,
       errorOnDuplicatedFieldNames: Boolean,
-      context: TaskContext): Iterator[InternalRow] = new InternalRowIteratorWithoutSchema(
-    arrowBatchIter, schema, timeZoneId, errorOnDuplicatedFieldNames, context
-  )
+      largeVarTypes: Boolean,
+      context: TaskContext): Iterator[InternalRow] = {
+    new InternalRowIteratorWithoutSchema(
+      arrowBatchIter, schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes, context
+    )
+  }

   /**
    * Maps iterator from serialized ArrowRecordBatches to InternalRows. Different from
@@ -393,6 +416,7 @@ private[sql] object ArrowConverters extends Logging {
     val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
     val attrs = toAttributes(schema)
     val batchesInDriver = arrowBatches.toArray
+    val largeVarTypes = session.sessionState.conf.arrowUseLargeVarTypes
     val shouldUseRDD = session.sessionState.conf
       .arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum
@@ -407,6 +431,7 @@ private[sql] object ArrowConverters extends Logging {
           schema,
           timezone,
           errorOnDuplicatedFieldNames = false,
+          largeVarTypes = largeVarTypes,
           TaskContext.get())
       }
       session.internalCreateDataFrame(rdd.setName("arrow"), schema)
@@ -417,6 +442,7 @@ private[sql] object ArrowConverters extends Logging {
         schema,
         session.sessionState.conf.sessionLocalTimeZone,
         errorOnDuplicatedFieldNames = false,
+        largeVarTypes = largeVarTypes,
         TaskContext.get())

       // Project/copy it. Otherwise, the Arrow column vectors will be closed and released out.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 1bddd81fbfe20..bf21424225621 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -120,6 +120,9 @@ object ArrowPythonRunner {
     val arrowAyncParallelism = conf.pythonUDFArrowConcurrencyLevel.map(v =>
       Seq(SQLConf.PYTHON_UDF_ARROW_CONCURRENCY_LEVEL.key -> v.toString)
     ).getOrElse(Seq.empty)
-    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++ arrowAyncParallelism: _*)
+    val useLargeVarTypes = Seq(SQLConf.ARROW_EXECUTION_USE_LARGE_VAR_TYPES.key ->
+      conf.arrowUseLargeVarTypes.toString)
+    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++
+      arrowAyncParallelism ++ useLargeVarTypes: _*)
   }
 }
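getPythonRunnerConfMap is what ships the setting to the Python workers, so pandas UDFs see the same choice of Arrow types as the JVM side. A minimal sketch of the user-facing switch (assuming a live SparkSession named `spark`; the described effect is the intent of this patch, not an additional API):

    // The only knob is the existing SQL conf; Arrow exchanges created afterwards
    // (pandas UDFs, toPandas with Arrow, createDataFrame with Arrow) are expected
    // to carry large_string / large_binary columns instead of string / binary.
    spark.conf.set("spark.sql.execution.arrow.useLargeVarTypes", "true")
    assert(spark.conf.get("spark.sql.execution.arrow.useLargeVarTypes") == "true")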
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 59e8970b9c9b6..9caa344d00c58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -45,6 +45,7 @@ class CoGroupedArrowPythonRunner(
     leftSchema: StructType,
     rightSchema: StructType,
     timeZoneId: String,
+    largeVarTypes: Boolean,
     conf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
@@ -109,7 +110,8 @@ class CoGroupedArrowPythonRunner(
         dataOut: DataOutputStream,
         name: String): Unit = {
       val arrowSchema =
-        ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true)
+        ArrowUtils.toArrowSchema(
+          schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = largeVarTypes)
       val allocator = ArrowUtils.rootAllocator.newChildAllocator(
         s"stdout writer for $pythonExec ($name)", 0, Long.MaxValue)
       val root = VectorSchemaRoot.create(arrowSchema, allocator)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
index 66ed2bca76775..af487218391e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInBatchExec.scala
@@ -42,6 +42,7 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
   protected val pythonEvalType: Int

   private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val largeVarTypes = conf.arrowUseLargeVarTypes
   private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
   private val pythonUDF = func.asInstanceOf[PythonUDF]
   private val pandasFunction = pythonUDF.func
@@ -84,6 +85,7 @@ trait FlatMapCoGroupsInBatchExec extends SparkPlan with BinaryExecNode with Pyth
         DataTypeUtils.fromAttributes(leftDedup),
         DataTypeUtils.fromAttributes(rightDedup),
         sessionLocalTimeZone,
+        largeVarTypes,
         pythonRunnerConf,
         pythonMetrics,
         jobArtifactUUID,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
index 45ecf87009505..aaf2f256273d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
@@ -85,7 +85,8 @@ class ArrowRRunner(
   override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
     if (inputIterator.hasNext) {
       val arrowSchema =
-        ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true)
+        ArrowUtils.toArrowSchema(
+          schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = false)
       val allocator = ArrowUtils.rootAllocator.newChildAllocator(
         "stdout writer for R", 0, Long.MaxValue)
       val root = VectorSchemaRoot.create(arrowSchema, allocator)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index 33e5d46ee2333..39c3d8df7550e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -1376,8 +1376,10 @@ class ArrowConvertersSuite extends SharedSparkSession {
     val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
     val ctx = TaskContext.empty()

-    val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, true, ctx)
-    val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, true, ctx)
+    val batchIter = ArrowConverters.toBatchIterator(
+      inputRows.iterator, schema, 5, null, true, false, ctx)
+    val outputRowIter = ArrowConverters.fromBatchIterator(
+      batchIter, schema, null, true, false, ctx)

     var count = 0
     outputRowIter.zipWithIndex.foreach { case (row, i) =>
@@ -1397,12 +1399,13 @@ class ArrowConvertersSuite extends SharedSparkSession {
     val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
     val ctx = TaskContext.empty()

-    val batchIter = ArrowConverters.toBatchIterator(inputRows.iterator, schema, 5, null, true, ctx)
+    val batchIter = ArrowConverters.toBatchIterator(
+      inputRows.iterator, schema, 5, null, true, false, ctx)

     // Write batches to Arrow stream format as a byte array
     val out = new ByteArrayOutputStream()
     Utils.tryWithResource(new DataOutputStream(out)) { dataOut =>
-      val writer = new ArrowBatchStreamWriter(schema, dataOut, null, true)
+      val writer = new ArrowBatchStreamWriter(schema, dataOut, null, true, false)
       writer.writeBatches(batchIter)
       writer.end()
     }
@@ -1410,7 +1413,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
     // Read Arrow stream into batches, then convert back to rows
     val in = new ByteArrayReadableSeekableByteChannel(out.toByteArray)
     val readBatches = ArrowConverters.getBatchesFromStream(in)
-    val outputRowIter = ArrowConverters.fromBatchIterator(readBatches, schema, null, true, ctx)
+    val outputRowIter = ArrowConverters.fromBatchIterator(
+      readBatches, schema, null, true, false, ctx)

     var count = 0
     outputRowIter.zipWithIndex.foreach { case (row, i) =>
@@ -1441,7 +1445,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
     }
     val ctx = TaskContext.empty()
     val batchIter = ArrowConverters.toBatchWithSchemaIterator(
-      inputRows.iterator, schema, 5, 1024 * 1024, null, true)
+      inputRows.iterator, schema, 5, 1024 * 1024, null, true, false)
     val (outputRowIter, outputType) = ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx)

     var count = 0
@@ -1460,7 +1464,8 @@ class ArrowConvertersSuite extends SharedSparkSession {
     val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
     val ctx = TaskContext.empty()
     val batchIter =
-      ArrowConverters.toBatchWithSchemaIterator(Iterator.empty, schema, 5, 1024 * 1024, null, true)
+      ArrowConverters.toBatchWithSchemaIterator(
+        Iterator.empty, schema, 5, 1024 * 1024, null, true, false)
     val (outputRowIter, outputType) = ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx)

     assert(0 == outputRowIter.length)
@@ -1474,7 +1479,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
       proj(row).copy()
     }
     val batchIter1 = ArrowConverters.toBatchWithSchemaIterator(
-      inputRows1.iterator, schema1, 5, 1024 * 1024, null, true)
+      inputRows1.iterator, schema1, 5, 1024 * 1024, null, true, false)

     val schema2 = StructType(Seq(StructField("field2", IntegerType, nullable = true)))
     val inputRows2 = Array(InternalRow(1)).map { row =>
@@ -1482,7 +1487,7 @@ class ArrowConvertersSuite extends SharedSparkSession {
       proj(row).copy()
     }
     val batchIter2 = ArrowConverters.toBatchWithSchemaIterator(
-      inputRows2.iterator, schema2, 5, 1024 * 1024, null, true)
+      inputRows2.iterator, schema2, 5, 1024 * 1024, null, true, false)

     val iter = batchIter1.toArray ++ batchIter2

@@ -1511,11 +1516,13 @@ class ArrowConvertersSuite extends SharedSparkSession {
       batchBytes: Array[Byte],
       jsonFile: File,
       timeZoneId: String = null,
-      errorOnDuplicatedFieldNames: Boolean = true): Unit = {
+      errorOnDuplicatedFieldNames: Boolean = true,
+      largeVarTypes: Boolean = false): Unit = {
     val allocator = new RootAllocator(Long.MaxValue)

     val jsonReader = new JsonFileReader(jsonFile, allocator)

-    val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId, errorOnDuplicatedFieldNames)
+    val arrowSchema = ArrowUtils.toArrowSchema(
+      sparkSchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
     val jsonSchema = jsonReader.start()
     Validator.compareSchemas(arrowSchema, jsonSchema)
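A hypothetical companion test, not included in this patch, that round-trips a string column with largeVarTypes enabled; it mirrors the integer round-trips above and only relies on the signatures introduced here (it assumes it lives inside ArrowConvertersSuite, since TaskContext.empty() and ArrowConverters are Spark-internal):

    import org.apache.spark.TaskContext
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.execution.arrow.ArrowConverters
    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    import org.apache.spark.unsafe.types.UTF8String

    // Encode a few strings as UnsafeRows, write them out as Arrow batches with
    // largeVarTypes = true, read them back, and check the values survive.
    val strSchema = StructType(Seq(StructField("str", StringType, nullable = true)))
    val strProj = UnsafeProjection.create(strSchema)
    val strRows = Array("a", "bb", "ccc").map { s =>
      strProj(InternalRow(UTF8String.fromString(s))).copy()
    }
    val strCtx = TaskContext.empty()
    val strBatches = ArrowConverters.toBatchIterator(
      strRows.iterator, strSchema, 5, null, true, true, strCtx)
    val strOutput = ArrowConverters.fromBatchIterator(
      strBatches, strSchema, null, true, true, strCtx)
    assert(strOutput.map(_.getUTF8String(0).toString).toList == List("a", "bb", "ccc"))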
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index 7830ea1da1774..acf258a373c36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -158,7 +158,8 @@ class ArrowWriterSuite extends SparkFunSuite {
       schema: StructType,
       timeZoneId: String): (ArrowWriter, Int) = {
     val arrowSchema =
-      ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames = true)
+      ArrowUtils.toArrowSchema(
+        schema, timeZoneId, errorOnDuplicatedFieldNames = true, largeVarTypes = false)
     val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator)
     val vector = root.getFieldVectors.get(0)
     vector.allocateNew()
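For completeness, a sketch of the large-type variant of the helper above (hypothetical, not in this patch): with largeVarTypes = true the first field vector should come out as a LargeVarCharVector rather than a VarCharVector.

    import org.apache.arrow.vector.{LargeVarCharVector, VectorSchemaRoot}
    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    import org.apache.spark.sql.util.ArrowUtils

    // Build the Arrow schema with large variable types and materialize the vectors.
    val schema = StructType(Seq(StructField("s", StringType)))
    val arrowSchema = ArrowUtils.toArrowSchema(
      schema, null, errorOnDuplicatedFieldNames = true, largeVarTypes = true)
    val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator)
    try {
      assert(root.getFieldVectors.get(0).isInstanceOf[LargeVarCharVector])
    } finally {
      root.close()
    }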