Fix a hive write test failure (#10958)
This is a bug fix for the hive write tests. In some of the tests on Spark 3.5.1,
ProjectExec falls back to the CPU because there is no GPU version of the MapFromArrays expression.

This PR adds ProjectExec to the allowed fallback list for Spark 3.5.1 and later.

Signed-off-by: Firestarman <firestarmanllc@gmail.com>
firestarman authored Jun 5, 2024
1 parent 3111e2b commit 149e0d5
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 2 additions & 2 deletions integration_tests/src/main/python/hive_parquet_write_test.py
@@ -19,7 +19,7 @@
from data_gen import *
from hive_write_test import _restricted_timestamp
from marks import allow_non_gpu, ignore_order
-from spark_session import with_cpu_session, is_before_spark_320
+from spark_session import with_cpu_session, is_before_spark_320, is_spark_351_or_later

# Disable the meta conversion from Hive write to FrameData write in Spark, to test
# "GpuInsertIntoHiveTable" for Parquet write.
@@ -55,7 +55,7 @@
_hive_write_gens = [_hive_basic_gens, _hive_struct_gens, _hive_array_gens, _hive_map_gens]

# ProjectExec falls back on databricks due to no GPU version of "MapFromArrays".
-fallback_nodes = ['ProjectExec'] if is_databricks_runtime() else []
+fallback_nodes = ['ProjectExec'] if is_databricks_runtime() or is_spark_351_or_later() else []


@allow_non_gpu(*(non_utc_allow + fallback_nodes))
5 changes: 4 additions & 1 deletion integration_tests/src/main/python/spark_session.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -220,6 +220,9 @@ def is_spark_341():
def is_spark_350_or_later():
return spark_version() >= "3.5.0"

+def is_spark_351_or_later():
+    return spark_version() >= "3.5.1"

def is_spark_330():
return spark_version() == "3.3.0"
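The version helpers in `spark_session.py` gate behavior on a lexicographic string comparison of the Spark version. A minimal standalone sketch of the pattern, using a plain version string in place of the real `spark_version()` helper (which reads the version from the active Spark session), might look like this:

```python
# Sketch of the version-gating pattern from spark_session.py; the
# `version` parameter stands in for the real spark_version() helper.
def is_spark_351_or_later(version: str) -> bool:
    # Lexicographic comparison is safe for "X.Y.Z" versions with
    # single-digit components, which holds across the Spark 3.x line.
    return version >= "3.5.1"

# Mirrors the fallback allow-list in hive_parquet_write_test.py:
# ProjectExec is expected to fall back to the CPU when MapFromArrays
# has no GPU version (Databricks, or Spark 3.5.1 and later).
def fallback_nodes_for(version: str, is_databricks: bool = False) -> list:
    return ['ProjectExec'] if is_databricks or is_spark_351_or_later(version) else []
```

The tests then pass these nodes to `@allow_non_gpu`, so the expected CPU fallback does not count as a failure.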

