Fix HybridParquetScan over select(1)
Signed-off-by: sperlingxx <lovedreamf@gmail.com>
sperlingxx committed Feb 12, 2025
1 parent 2cc2992 commit 7f890dd
Showing 2 changed files with 20 additions and 9 deletions.
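
For context, the regression involves a Parquet read whose only projection is count(1): no columns are needed at the scan, so the required schema is empty. A minimal PySpark sketch of that query shape (the local session and path are illustrative, not part of this commit):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[1]').getOrCreate()

path = '/tmp/hybrid_parquet_demo'  # illustrative path
spark.range(10).write.mode('overwrite').parquet(path)

# Counting rows projects no columns, so with the v1 Parquet source the
# plan's scan node reports an empty read schema: `ReadSchema: struct<>`.
spark.read.parquet(path).selectExpr('count(1)').explain()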
23 changes: 16 additions & 7 deletions integration_tests/src/main/python/hybrid_parquet_test.py
@@ -70,6 +70,7 @@
     [StructGen([["c0", simple_string_to_string_map_gen]])],
     [StructGen([["c0", ArrayGen(simple_string_to_string_map_gen)]])],
     [StructGen([["c0", StructGen([["cc0", simple_string_to_string_map_gen]])]])],
+    [],
 ]


@@ -132,21 +133,29 @@ def test_hybrid_parquet_read_round_trip_multiple_batches(spark_tmp_path,
 @pytest.mark.parametrize('parquet_gens', parquet_gens_fallback_lists, ids=idfn)
 @hybrid_test
 def test_hybrid_parquet_read_fallback_to_gpu(spark_tmp_path, parquet_gens):
-    gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
     data_path = spark_tmp_path + '/PARQUET_DATA'
-    with_cpu_session(
-        lambda spark: gen_df(spark, gen_list, length=512).write.parquet(data_path),
-        conf=rebase_write_corrected_conf)
 
+    # check the fallback over empty schema(`SELECT COUNT(1)`) within the same case
+    if len(parquet_gens) == 0:
+        with_cpu_session(
+            lambda spark: gen_df(spark, [('a', int_gen)], length=512).write.parquet(data_path),
+            conf=rebase_write_corrected_conf)
+        read_fn = lambda spark: spark.read.parquet(data_path).selectExpr('count(1)')
+    else:
+        gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)]
+        with_cpu_session(
+            lambda spark: gen_df(spark, gen_list, length=512).write.parquet(data_path),
+            conf=rebase_write_corrected_conf)
+        read_fn = lambda spark: spark.read.parquet(data_path)
     assert_cpu_and_gpu_are_equal_collect_with_capture(
-        lambda spark: spark.read.parquet(data_path),
+        read_fn,
         exist_classes='GpuFileSourceScanExec',
         non_exist_classes='HybridFileSourceScanExec',
         conf={
             'spark.sql.sources.useV1SourceList': 'parquet',
             'spark.rapids.sql.parquet.useHybridReader': 'true',
         })



 filter_split_conf = {
     'spark.sql.sources.useV1SourceList': 'parquet',
     'spark.rapids.sql.parquet.useHybridReader': 'true',
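
The capture helper above asserts on the node classes in the physical plan. A rough, hypothetical standalone rendition of that step, continuing the sketch from the top (assert_scan_classes is not the real harness helper, and GpuFileSourceScanExec only appears when the spark-rapids plugin is loaded, so the CPU scan class is checked here):

spark.conf.set('spark.sql.sources.useV1SourceList', 'parquet')
spark.conf.set('spark.rapids.sql.parquet.useHybridReader', 'true')

def assert_scan_classes(df, exist_class, non_exist_class):
    # Render the executed physical plan via py4j and check which scan
    # implementation was planned, mirroring exist_classes/non_exist_classes.
    plan = df._jdf.queryExecution().executedPlan().toString()
    assert exist_class in plan
    assert non_exist_class not in plan

count_query = spark.read.parquet(path).selectExpr('count(1)')
assert_scan_classes(count_query, 'FileSourceScan', 'HybridFileSourceScanExec')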
@@ -55,7 +55,9 @@ object HybridExecutionUtils extends PredicateHelper {
       false
     }
     // Currently, only support reading Parquet
-    lazy val isParquet = fsse.relation.fileFormat.getClass == classOf[ParquetFileFormat]
+    val isParquet = fsse.relation.fileFormat.getClass == classOf[ParquetFileFormat]
+    // Fallback to GpuScan over the `select count(1)` cases
+    val nonEmptySchema = fsse.output.nonEmpty && fsse.requiredSchema.nonEmpty
     // Check if data types of all fields are supported by HybridParquetReader
     lazy val allSupportedTypes = !fsse.requiredSchema.exists { field =>
       TrampolineUtil.dataTypeExistsRecursively(field.dataType, {
@@ -78,7 +80,7 @@
     // TODO: supports BucketedScan
     lazy val noBucketedScan = !fsse.bucketedScan
 
-    isEnabled && isParquet && allSupportedTypes && noBucketedScan
+    isEnabled && isParquet && nonEmptySchema && allSupportedTypes && noBucketedScan
   }
 
   /**
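The net effect of the Scala change is one extra conjunct in the eligibility check: the hybrid native reader is only considered when the scan actually materializes columns, and a pure row-count scan now falls back to GpuFileSourceScanExec up front. A hypothetical Python mirror of that predicate (names follow the Scala vals; this is pseudocode for illustration, not the plugin API):

def use_hybrid_scan(is_enabled, is_parquet, output, required_schema,
                    all_supported_types, no_bucketed_scan):
    # `select count(1)` leaves both the scan output and the required
    # schema empty, so non_empty_schema disqualifies the hybrid reader.
    non_empty_schema = len(output) > 0 and len(required_schema) > 0
    return (is_enabled and is_parquet and non_empty_schema
            and all_supported_types and no_bucketed_scan)

print(use_hybrid_scan(True, True, output=[], required_schema=[],
                      all_supported_types=True, no_bucketed_scan=True))  # False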
