Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Case when performance improvement: reduce the copy_if_else [databricks] #10951

Merged
merged 15 commits into from
Jul 12, 2024

Conversation

res-life
Copy link
Collaborator

@res-life res-life commented May 30, 2024

contributes to #2084

Current case when logic: link

Iteratively invoke copy_if_else to merge the tail 2 branches.
This will create lots of temp string columns via copy_if_else which will introduce more spend time.

      val elseRet = elseValue
        .map(_.columnarEvalAny(batch))
        .getOrElse(GpuScalar(null, branches.last._2.dataType))
      val any = branches.foldRight[Any](elseRet) {
        case ((predicateExpr, trueExpr), falseRet) =>
          computeIfElse(batch, predicateExpr, trueExpr, falseRet)
      }

new logic

For the case when with multiple branches and the then values are all string scalars.
e.g.:

CASE 
  WHEN expr1 THEN 's1'
  WHEN expr2 THEN 's2'
  WHEN expr2 THEN 's3' 
   ...
  ELSE 's-default'
END

This PR removed the copy_if_elses.
Steps:

  • First evaluate all the expr1 expr2...., and get bool columns.
  • Generate a column for all the scalars.
  • Then select from bool table(bool columns), get all the expected index of scalars for each row, and generate a int column contains scalar indexes.
  • Then select scalars from scalar column according to the selecar indexes, and generate a String column.

For more details, refer to JNI PR: NVIDIA/spark-rapids-jni#2079
Depending on PR: NVIDIA/spark-rapids-jni#2079

Signed-off-by: Chong Gao res_life@163.com

Signed-off-by: Chong Gao <res_life@163.com>
@res-life res-life changed the base branch from branch-24.06 to branch-24.08 May 30, 2024 08:38
Copy link
Collaborator

@winningsix winningsix left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add some UTs here.

) {
// return type is string type; all the then and else exprs are Scalars
// avoid to use multiple `computeIfElse`s which will create multiple temp columns
logWarning("==================== Running case with experimental =========== ")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to remove this debug log?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will remove

logWarning("==================== Running case with experimental =========== ")

val whenBoolCols = branches.safeMap(_._1.columnarEval(batch).getBase).toArray
val firstTrueIndex: ColumnVector = withResource(whenBoolCols) { _ =>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: selectivityCol better name?

@res-life
Copy link
Collaborator Author

Have build error because of depending on JNI PR.

@res-life
Copy link
Collaborator Author

res-life commented Jun 3, 2024

build

@res-life
Copy link
Collaborator Author

res-life commented Jun 3, 2024

End to End Perf test result:

case when branches legacy min time new impl min time speedup
3 147ms 129ms 1.14x
10 298ms 197ms 1.51x

Test steps:

  • create 5 parquet files with 10 bool columns in it.
    Each file contains 10,000,000 rows:
+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
|  b01|  b02|  b03|  b04|  b05|  b06|  b07|  b08|  b09|  b10|
+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
|false| true|false|false| true|false|false| true|false|false|
| true|false| true|false|false|false| true| true|false|false|
|false| true| true| true| true|false| true|false|false| true|
| true| true|false|false| true|false|false|false| true|false|
|false| true|false| true|false|false|false| true| true| true|
| true|false|false| true|false|false|false| true| true| true|
|false|false|false|false| true| true|false|false|false| true|
  • SQL
    10 branches:
"""
select
  sum(octet_length(
    CASE WHEN b01 THEN 'go-http-client'
         WHEN b02 THEN 'libcurlOT'
         WHEN b03 THEN 'libcurl'
         WHEN b04 THEN 'wspcdn'
         WHEN b05 THEN 'byte-pcdn-vod'
         WHEN b06 THEN 'byte-pcdn-live'
         WHEN b07 THEN 'byte-pcdn/vod-go'
         WHEN b08 THEN 'byte-pcdn/vod-bs'
         WHEN b09 THEN 'kcg-cache'
         WHEN b10 THEN 'byte-pcdn-vod'
         ELSE 'unknown'
    END
  ))
from tab
"""

3 branches:

"""
select
  sum(octet_length(
    CASE WHEN b01 THEN 'go-http-client'
         WHEN b02 THEN 'libcurlOT'
         WHEN b03 THEN 'libcurl'
         ELSE 'unknown'
    END
  ))
from tab
"""

@res-life res-life marked this pull request as ready for review June 3, 2024 12:44
@sameerz sameerz added the performance A performance related task/issue label Jun 3, 2024
@res-life
Copy link
Collaborator Author

res-life commented Jun 4, 2024

This PR only improves the case that all when expressions are scalars.
For the general case that when expressions are not all scalars, we may file a follow-up to imporve.
IMO, for the case that return type is string, the method in this PR will help; For other type, it may not help much.

This PR only improves the non side effect code.

case ((predicateExpr, trueExpr), falseRet) =>
computeIfElse(batch, predicateExpr, trueExpr, falseRet)
if (caseWhenFuseEnabled && branches.size > 2 &&
inputTypesForMerging.head == StringType &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why just strings? The gather used should be generic enough to do anything.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For other type, need to test, will update if get perf improvement for other types.

@res-life
Copy link
Collaborator Author

build

2 similar comments
@res-life
Copy link
Collaborator Author

build

@res-life
Copy link
Collaborator Author

build

@res-life
Copy link
Collaborator Author

res-life commented Jul 4, 2024

Perf tests for other types:

type when branches num new MS(enable optimazation) (2nd,3rd,4th) baseline MS(disable optimazation) (2nd,3rd,4th) end to end speedup (new MS/baseline MS)
bool 3 133, 124, 121 144, 139, 139 1.11640211640212
int 3 159,148, 146 160, 145, 157 1.01986754966887
byte 3 127, 140, 130 151, 147, 152 1.13350125944584
float 3 120, 133, 131 153, 143, 149 1.15885416666667
decimal32 3 126, 135, 122 140, 141, 133 1.08093994778068
bool 10 172, 199, 205 267, 280, 280 1.43576388888889
int 10 215,215, 207 263, 273, 266 1.25902668759812
byte 10 210, 180, 201 261, 287, 289 1.41624365482234
float 10 194, 177, 209 281, 262, 275 1.41034482758621
decimal32 10 177, 194, 200 265, 268, 270 1.40630472854641

@res-life
Copy link
Collaborator Author

res-life commented Jul 4, 2024

build

@res-life res-life changed the title Case when performance improvement: reduce the copy_if_else Case when performance improvement: reduce the copy_if_else [databricks] Jul 4, 2024
@res-life
Copy link
Collaborator Author

res-life commented Jul 4, 2024

@revans2 Please help review.

@res-life
Copy link
Collaborator Author

res-life commented Jul 5, 2024

build

@res-life res-life requested a review from revans2 July 9, 2024 01:23
@@ -296,3 +296,85 @@ def test_conditional_with_side_effects_unary_minus(data_gen, ansi_enabled):
'CASE WHEN a > -32768 THEN -a ELSE null END'),
conf = {'spark.sql.ansi.enabled': ansi_enabled})

_case_when_scalars = [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to leverage the data_gen to build a more robust test case just as other test?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case test_case_when_all_then_values_are_scalars, for the bool columns, the data_gen is as following:

    data_gen = [
        ("a", boolean_gen),
        ("b", boolean_gen),
        ("c", boolean_gen),
        ("d", boolean_gen),
        ("e", boolean_gen)
    ]

    sql =  """
            select case
                when a then {}
                when b then {}
                when c then {}
                when d then {}
                else {}
            end
            from tab
            """

The _case_when_scalars specify the const scalars, data_gen can not be used.

@res-life
Copy link
Collaborator Author

build

@res-life
Copy link
Collaborator Author

build

@firestarman
Copy link
Collaborator

LGTM

@res-life
Copy link
Collaborator Author

@res-life
Copy link
Collaborator Author

build

Copy link
Collaborator

@firestarman firestarman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM again

Copy link
Collaborator

@revans2 revans2 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Learned something new about the unapply in Spark for decimals. Thanks.

@res-life res-life merged commit 451463f into NVIDIA:branch-24.08 Jul 12, 2024
45 checks passed
@res-life res-life deleted the case-when-perf branch July 12, 2024 01:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
performance A performance related task/issue
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants