-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-21717][SQL] Decouple consume functions of physical operators in whole-stage codegen #18931
Changes from all commits
05274e7
e0e7a6e
413707d
0bb8c0e
6d600d5
502139a
5fe3762
4bef567
1694c9b
8f3b984
c04da15
9540195
1101b2c
ff77bfe
e36ec3c
edb73d6
601c225
476994f
bdc1146
58eaf00
2f2d1fd
9f0d1da
79d0106
6384aec
0c4173e
c859d53
11946e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution | |
|
||
import java.util.Locale | ||
|
||
import scala.collection.mutable | ||
|
||
import org.apache.spark.broadcast | ||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.sql.catalyst.InternalRow | ||
|
@@ -106,6 +108,31 @@ trait CodegenSupport extends SparkPlan { | |
*/ | ||
protected def doProduce(ctx: CodegenContext): String | ||
|
||
private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = { | ||
if (row != null) { | ||
ExprCode("", "false", row) | ||
} else { | ||
if (colVars.nonEmpty) { | ||
val colExprs = output.zipWithIndex.map { case (attr, i) => | ||
BoundReference(i, attr.dataType, attr.nullable) | ||
} | ||
val evaluateInputs = evaluateVariables(colVars) | ||
// generate the code to create a UnsafeRow | ||
ctx.INPUT_ROW = row | ||
ctx.currentVars = colVars | ||
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) | ||
val code = s""" | ||
|$evaluateInputs | ||
|${ev.code.trim} | ||
""".stripMargin.trim | ||
ExprCode(code, "false", ev.value) | ||
} else { | ||
// There are no columns | ||
ExprCode("", "false", "unsafeRow") | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`. | ||
* | ||
|
@@ -126,28 +153,7 @@ trait CodegenSupport extends SparkPlan { | |
} | ||
} | ||
|
||
val rowVar = if (row != null) { | ||
ExprCode("", "false", row) | ||
} else { | ||
if (outputVars.nonEmpty) { | ||
val colExprs = output.zipWithIndex.map { case (attr, i) => | ||
BoundReference(i, attr.dataType, attr.nullable) | ||
} | ||
val evaluateInputs = evaluateVariables(outputVars) | ||
// generate the code to create a UnsafeRow | ||
ctx.INPUT_ROW = row | ||
ctx.currentVars = outputVars | ||
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) | ||
val code = s""" | ||
|$evaluateInputs | ||
|${ev.code.trim} | ||
""".stripMargin.trim | ||
ExprCode(code, "false", ev.value) | ||
} else { | ||
// There are no columns | ||
ExprCode("", "false", "unsafeRow") | ||
} | ||
} | ||
val rowVar = prepareRowVar(ctx, row, outputVars) | ||
|
||
// Set up the `currentVars` in the codegen context, as we generate the code of `inputVars` | ||
// before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to | ||
|
@@ -156,13 +162,96 @@ trait CodegenSupport extends SparkPlan { | |
ctx.INPUT_ROW = null | ||
ctx.freshNamePrefix = parent.variablePrefix | ||
val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) | ||
|
||
// Under certain conditions, we can put the logic to consume the rows of this operator into | ||
// another function. So we can prevent a generated function too long to be optimized by JIT. | ||
// The conditions: | ||
// 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled. | ||
// 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses | ||
// all variables in output (see `requireAllOutput`). | ||
// 3. The number of output variables must be less than the maximum number of parameters in a Java | ||
// declaration. | ||
val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator | ||
val requireAllOutput = output.forall(parent.usedInputs.contains(_)) | ||
val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0) | ||
val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) { | ||
constructDoConsumeFunction(ctx, inputVars, row) | ||
} else { | ||
parent.doConsume(ctx, inputVars, rowVar) | ||
} | ||
s""" | ||
|${ctx.registerComment(s"CONSUME: ${parent.simpleString}")} | ||
|$evaluated | ||
|${parent.doConsume(ctx, inputVars, rowVar)} | ||
|$consumeFunc | ||
""".stripMargin | ||
} | ||
|
||
/** | ||
* To prevent the concatenated function from growing too long to be optimized by JIT, we can separate the | ||
* parent's `doConsume` codes of a `CodegenSupport` operator into a function to call. | ||
*/ | ||
private def constructDoConsumeFunction( | ||
ctx: CodegenContext, | ||
inputVars: Seq[ExprCode], | ||
row: String): String = { | ||
val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row) | ||
val rowVar = prepareRowVar(ctx, row, inputVarsInFunc) | ||
|
||
val doConsume = ctx.freshName("doConsume") | ||
ctx.currentVars = inputVarsInFunc | ||
ctx.INPUT_ROW = null | ||
|
||
val doConsumeFuncName = ctx.addNewFunction(doConsume, | ||
s""" | ||
| private void $doConsume(${params.mkString(", ")}) throws java.io.IOException { | ||
| ${parent.doConsume(ctx, inputVarsInFunc, rowVar)} | ||
| } | ||
""".stripMargin) | ||
|
||
s""" | ||
| $doConsumeFuncName(${args.mkString(", ")}); | ||
""".stripMargin | ||
} | ||
|
||
/** | ||
* Returns arguments for calling method and method definition parameters of the consume function. | ||
* And also returns the list of `ExprCode` for the parameters. | ||
*/ | ||
private def constructConsumeParameters( | ||
ctx: CodegenContext, | ||
attributes: Seq[Attribute], | ||
variables: Seq[ExprCode], | ||
row: String): (Seq[String], Seq[String], Seq[ExprCode]) = { | ||
val arguments = mutable.ArrayBuffer[String]() | ||
val parameters = mutable.ArrayBuffer[String]() | ||
val paramVars = mutable.ArrayBuffer[ExprCode]() | ||
|
||
if (row != null) { | ||
arguments += row | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we need to update There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should probably have 2 methods for calculating param length and checking param length limitation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added an extra unit for |
||
parameters += s"InternalRow $row" | ||
} | ||
|
||
variables.zipWithIndex.foreach { case (ev, i) => | ||
val paramName = ctx.freshName(s"expr_$i") | ||
val paramType = ctx.javaType(attributes(i).dataType) | ||
|
||
arguments += ev.value | ||
parameters += s"$paramType $paramName" | ||
val paramIsNull = if (!attributes(i).nullable) { | ||
// Use constant `false` without passing `isNull` for non-nullable variable. | ||
"false" | ||
} else { | ||
val isNull = ctx.freshName(s"exprIsNull_$i") | ||
arguments += ev.isNull | ||
parameters += s"boolean $isNull" | ||
isNull | ||
} | ||
|
||
paramVars += ExprCode("", paramIsNull, paramName) | ||
} | ||
(arguments, parameters, paramVars) | ||
} | ||
|
||
/** | ||
* Returns source code to evaluate all the variables, and clear the code of them, to prevent | ||
* them to be evaluated twice. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -205,7 +205,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { | |
val codeWithShortFunctions = genGroupByCode(3) | ||
val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) | ||
assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) | ||
val codeWithLongFunctions = genGroupByCode(20) | ||
val codeWithLongFunctions = genGroupByCode(50) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We reduced the length of generated codes. So to make this test work, we increase the number of expressions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In my pr, I changed the code to just check if long functions have the larger value of max code size: |
||
val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) | ||
assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) | ||
} | ||
|
@@ -228,4 +228,49 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { | |
} | ||
} | ||
} | ||
|
||
test("Control splitting consume function by operators with config") { | ||
import testImplicits._ | ||
val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*) | ||
|
||
Seq(true, false).foreach { config => | ||
withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") { | ||
val plan = df.queryExecution.executedPlan | ||
val wholeStageCodeGenExec = plan.find(p => p match { | ||
case wp: WholeStageCodegenExec => true | ||
case _ => false | ||
}) | ||
assert(wholeStageCodeGenExec.isDefined) | ||
val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 | ||
assert(code.body.contains("project_doConsume") == config) | ||
} | ||
} | ||
} | ||
|
||
test("Skip splitting consume function when parameter number exceeds JVM limit") { | ||
import testImplicits._ | ||
|
||
Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) => | ||
withTempPath { dir => | ||
val path = dir.getCanonicalPath | ||
spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*) | ||
.write.mode(SaveMode.Overwrite).parquet(path) | ||
|
||
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", | ||
SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") { | ||
val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i") | ||
val df = spark.read.parquet(path).selectExpr(projection: _*) | ||
|
||
val plan = df.queryExecution.executedPlan | ||
val wholeStageCodeGenExec = plan.find(p => p match { | ||
case wp: WholeStageCodegenExec => true | ||
case _ => false | ||
}) | ||
assert(wholeStageCodeGenExec.isDefined) | ||
val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 | ||
assert(code.body.contains("project_doConsume") == hasSplit) | ||
} | ||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Set to true by default. If there is objection, I can change it to false.