From 7f890ddace3221747ae93e6fad4efbb1f70109cf Mon Sep 17 00:00:00 2001
From: sperlingxx
Date: Wed, 12 Feb 2025 20:53:55 +0900
Subject: [PATCH] Fix HybridParquetScan over select(1)

Signed-off-by: sperlingxx
---
 .../src/main/python/hybrid_parquet_test.py    | 23 +++++++++++++------
 .../rapids/hybrid/HybridExecutionUtils.scala  |  6 +++--
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/integration_tests/src/main/python/hybrid_parquet_test.py b/integration_tests/src/main/python/hybrid_parquet_test.py
index 977a7d3ab0c..e6e28e60e81 100644
--- a/integration_tests/src/main/python/hybrid_parquet_test.py
+++ b/integration_tests/src/main/python/hybrid_parquet_test.py
@@ -70,6 +70,7 @@
     [StructGen([["c0", simple_string_to_string_map_gen]])],
     [StructGen([["c0", ArrayGen(simple_string_to_string_map_gen)]])],
     [StructGen([["c0", StructGen([["cc0", simple_string_to_string_map_gen]])]])],
+    [],
 ]
 
 
@@ -132,21 +133,29 @@ def test_hybrid_parquet_read_round_trip_multiple_batches(spark_tmp_path,
 @pytest.mark.parametrize('parquet_gens', parquet_gens_fallback_lists, ids=idfn)
 @hybrid_test
 def test_hybrid_parquet_read_fallback_to_gpu(spark_tmp_path, parquet_gens):
-    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
     data_path = spark_tmp_path + '/PARQUET_DATA'
-    with_cpu_session(
-        lambda spark: gen_df(spark, gen_list, length=512).write.parquet(data_path),
-        conf=rebase_write_corrected_conf)
-
+    # check the fallback over empty schema(`SELECT COUNT(1)`) within the same case
+    if len(parquet_gens) == 0:
+        with_cpu_session(
+            lambda spark: gen_df(spark, [('a', int_gen)], length=512).write.parquet(data_path),
+            conf=rebase_write_corrected_conf)
+        read_fn = lambda spark: spark.read.parquet(data_path).selectExpr('count(1)')
+    else:
+        gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
+        with_cpu_session(
+            lambda spark: gen_df(spark, gen_list, length=512).write.parquet(data_path),
+            conf=rebase_write_corrected_conf)
+        read_fn = lambda spark: spark.read.parquet(data_path)
     assert_cpu_and_gpu_are_equal_collect_with_capture(
-        lambda spark: spark.read.parquet(data_path),
+        read_fn,
         exist_classes='GpuFileSourceScanExec',
         non_exist_classes='HybridFileSourceScanExec',
         conf={
             'spark.sql.sources.useV1SourceList': 'parquet',
             'spark.rapids.sql.parquet.useHybridReader': 'true',
         })
-
+
+
 filter_split_conf = {
     'spark.sql.sources.useV1SourceList': 'parquet',
     'spark.rapids.sql.parquet.useHybridReader': 'true',
diff --git a/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/HybridExecutionUtils.scala b/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/HybridExecutionUtils.scala
index 99dd09fec24..706384e3ab8 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/HybridExecutionUtils.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/rapids/hybrid/HybridExecutionUtils.scala
@@ -55,7 +55,9 @@ object HybridExecutionUtils extends PredicateHelper {
       false
     }
     // Currently, only support reading Parquet
-    lazy val isParquet = fsse.relation.fileFormat.getClass == classOf[ParquetFileFormat]
+    val isParquet = fsse.relation.fileFormat.getClass == classOf[ParquetFileFormat]
+    // Fallback to GpuScan over the `select count(1)` cases
+    val nonEmptySchema = fsse.output.nonEmpty && fsse.requiredSchema.nonEmpty
     // Check if data types of all fields are supported by HybridParquetReader
     lazy val allSupportedTypes = !fsse.requiredSchema.exists { field =>
       TrampolineUtil.dataTypeExistsRecursively(field.dataType, {
@@ -78,7 +80,7 @@ object HybridExecutionUtils extends PredicateHelper {
     }
     // TODO: supports BucketedScan
     lazy val noBucketedScan = !fsse.bucketedScan
-    isEnabled && isParquet && allSupportedTypes && noBucketedScan
+    isEnabled && isParquet && nonEmptySchema && allSupportedTypes && noBucketedScan
   }
 
   /**