[SPARK-21717][SQL] Decouple consume functions of physical operators in whole-stage codegen #18931

Status: Closed. viirya wants to merge 27 commits into master from SPARK-21717.

Changes shown are from 26 of the 27 commits:
05274e7
Decouple consume functions of physical operators in whole-stage codegen.
viirya Aug 13, 2017
e0e7a6e
shouldStop is called outside consume().
viirya Aug 13, 2017
413707d
Fix the condition and the case of using continue in consume.
viirya Aug 13, 2017
0bb8c0e
More comment.
viirya Aug 13, 2017
6d600d5
Fix aggregation.
viirya Aug 13, 2017
502139a
Also deal with sort case.
viirya Aug 13, 2017
5fe3762
Fix broadcasthash join.
viirya Aug 14, 2017
4bef567
Add more comments.
viirya Aug 14, 2017
1694c9b
Fix the cases where operators set up their produce framework.
viirya Aug 14, 2017
8f3b984
Fix Expand.
viirya Aug 14, 2017
c04da15
Fix BroadcastHashJoin.
viirya Aug 15, 2017
9540195
Rename variables.
viirya Aug 17, 2017
1101b2c
Don't create consume function if the number of arguments are more tha…
viirya Sep 1, 2017
ff77bfe
Merge remote-tracking branch 'upstream/master' into SPARK-21717
viirya Sep 26, 2017
e36ec3c
Remove the part of "continue" processing.
viirya Sep 26, 2017
edb73d6
Merge remote-tracking branch 'upstream/master' into SPARK-21717
viirya Oct 6, 2017
601c225
Fix test.
viirya Oct 7, 2017
476994f
More accurate calculation of valid method parameter length.
viirya Oct 11, 2017
bdc1146
Address comment.
viirya Oct 12, 2017
58eaf00
Address comments.
viirya Jan 24, 2018
2f2d1fd
Merge remote-tracking branch 'upstream/master' into SPARK-21717
viirya Jan 24, 2018
9f0d1da
Copy variables used for creating unsaferow.
viirya Jan 24, 2018
79d0106
Revert variables copying.
viirya Jan 24, 2018
6384aec
Add final to constants.
viirya Jan 24, 2018
0c4173e
Address comments.
viirya Jan 25, 2018
c859d53
Add tests.
viirya Jan 25, 2018
11946e7
Refactor isValidParamLength a bit.
viirya Jan 25, 2018
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1245,6 +1245,24 @@ class CodegenContext {
""
}
}

/**
* In Java, a method descriptor is valid only if it represents method parameters with a total
* length of 255 or less. `this` contributes one unit and a parameter of type long or double
* contributes two units. Besides, for nullable parameters, we also need to pass a boolean
* for the null status.
*/
def isValidParamLength(params: Seq[Expression]): Boolean = {
def calculateParamLength(input: Expression): Int = {
// A long or double occupies two parameter slots in a method descriptor.
val typeLength = javaType(input.dataType) match {
case JAVA_LONG | JAVA_DOUBLE => 2
case _ => 1
}
// For a nullable expression, we need to pass in an extra boolean parameter.
(if (input.nullable) 1 else 0) + typeLength
}
// Initial value is 1 for `this`.
1 + params.map(calculateParamLength).sum <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
}
}

/**
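To make the slot arithmetic above concrete, here is a minimal standalone sketch (an editorial illustration, not part of the patch; all names are hypothetical):

// Standalone sketch of JVM parameter-slot accounting. An instance method spends
// 1 slot on `this`, 2 slots per long/double parameter, and 1 slot per other
// parameter; codegen adds one boolean null-flag parameter per nullable column.
object ParamSlotSketch extends App {
  def slotsFor(javaType: String, nullable: Boolean): Int = {
    val typeSlots = javaType match {
      case "long" | "double" => 2
      case _ => 1
    }
    typeSlots + (if (nullable) 1 else 0)
  }

  // 100 nullable long columns: 1 (`this`) + 100 * (2 + 1) = 301 > 255, so a
  // consume function taking one parameter per column would not be valid.
  val total = 1 + Seq.fill(100)(slotsFor("long", nullable = true)).sum
  println(s"total = $total, valid = ${total <= 255}")  // total = 301, valid = false
}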
@@ -1311,26 +1329,29 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
object CodeGenerator extends Logging {

// This is the value of HugeMethodLimit in the OpenJDK JVM settings
val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000
final val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000

// The max valid length of method parameters in JVM.
final val MAX_JVM_METHOD_PARAMS_LENGTH = 255

// This is the threshold over which the methods in an inner class are grouped in a single
// method which is going to be called by the outer class instead of the many small ones
val MERGE_SPLIT_METHODS_THRESHOLD = 3
final val MERGE_SPLIT_METHODS_THRESHOLD = 3

// The number of named constants that can exist in the class is limited by the Constant Pool
// limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a
// threshold of 1000k bytes to determine when a function should be inlined to a private, inner
// class.
val GENERATED_CLASS_SIZE_THRESHOLD = 1000000
final val GENERATED_CLASS_SIZE_THRESHOLD = 1000000

// This is the threshold for the number of global variables, whose types are primitive type or
// complex type (e.g. more than one-dimensional array), that will be placed at the outer class
val OUTER_CLASS_VARIABLES_THRESHOLD = 10000
final val OUTER_CLASS_VARIABLES_THRESHOLD = 10000

// This is the maximum number of array elements to keep global variables in one Java array
// 32767 is the maximum integer value that does not require a constant pool entry in a Java
// bytecode instruction
val MUTABLESTATEARRAY_SIZE_LIMIT = 32768
final val MUTABLESTATEARRAY_SIZE_LIMIT = 32768

/**
* Compile the Java source code into a Java class, using Janino.
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -661,6 +661,15 @@ object SQLConf {
.intConf
.createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT)

val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR =
buildConf("spark.sql.codegen.splitConsumeFuncByOperator")
.internal()
.doc("When true, whole stage codegen would put the logic of consuming rows of each " +
"physical operator into individual methods, instead of a single big method. This can be " +
"used to avoid oversized function that can miss the opportunity of JIT optimization.")
.booleanConf
.createWithDefault(true)
[Review comment]
viirya (Member, Author): Set to true by default. If there is an objection, I can change it to false.
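For illustration, a minimal sketch of flipping the new flag at runtime (an assumption-laden example: it presumes an existing SparkSession named spark, and the conf is internal, so this is mainly useful for experiments):

// Toggle the internal flag and re-run a query to compare the generated code.
spark.conf.set("spark.sql.codegen.splitConsumeFuncByOperator", "false")
spark.range(100).selectExpr("id + 1 as c0").collect()  // consume logic stays inline

spark.conf.set("spark.sql.codegen.splitConsumeFuncByOperator", "true")  // default
spark.range(100).selectExpr("id + 1 as c0").collect()  // consume logic split per operator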


val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes")
.doc("The maximum number of bytes to pack into a single partition when reading files.")
.longConf
@@ -1263,6 +1272,9 @@ class SQLConf extends Serializable with Logging {

def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT)

def wholeStageSplitConsumeFuncByOperator: Boolean =
getConf(WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR)

def tableRelationCacheSize: Int =
getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE)

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution

import java.util.Locale

import scala.collection.mutable

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -106,6 +108,31 @@ trait CodegenSupport extends SparkPlan {
*/
protected def doProduce(ctx: CodegenContext): String

private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = {
if (row != null) {
ExprCode("", "false", row)
} else {
if (colVars.nonEmpty) {
val colExprs = output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}
val evaluateInputs = evaluateVariables(colVars)
// generate the code to create a UnsafeRow
ctx.INPUT_ROW = row
ctx.currentVars = colVars
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
val code = s"""
|$evaluateInputs
|${ev.code.trim}
""".stripMargin.trim
ExprCode(code, "false", ev.value)
} else {
// There are no columns
ExprCode("", "false", "unsafeRow")
}
}
}

/**
* Consume the generated columns or row from the current SparkPlan, and call its parent's `doConsume()`.
*
@@ -126,28 +153,7 @@ trait CodegenSupport extends SparkPlan {
}
}

val rowVar = if (row != null) {
ExprCode("", "false", row)
} else {
if (outputVars.nonEmpty) {
val colExprs = output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}
val evaluateInputs = evaluateVariables(outputVars)
// generate the code to create a UnsafeRow
ctx.INPUT_ROW = row
ctx.currentVars = outputVars
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
val code = s"""
|$evaluateInputs
|${ev.code.trim}
""".stripMargin.trim
ExprCode(code, "false", ev.value)
} else {
// There is no columns
ExprCode("", "false", "unsafeRow")
}
}
val rowVar = prepareRowVar(ctx, row, outputVars)

// Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
// before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
@@ -156,13 +162,96 @@ trait CodegenSupport extends SparkPlan {
ctx.INPUT_ROW = null
ctx.freshNamePrefix = parent.variablePrefix
val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)

// Under certain conditions, we can put the logic of consuming the rows of this operator into
// another function, so that we can prevent the generated function from growing too long to be
// optimized by JIT. The conditions are:
// 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled.
// 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses
//    all variables in output (see `requireAllOutput`).
// 3. The number of output variables must be less than the maximum number of parameters in a
//    Java method declaration.

[Review thread]
kiszk (Member, Aug 13, 2017): Could you elaborate on the "certain conditions" in the comment if you have time?
viirya (Member, Author): Added more comments to elaborate on the idea.
[Review thread]
Contributor: My only concern is that if we have a bunch of simple operators, we create a lot of small methods here. Maybe it's fine, as the optimizer would prevent such cases.
Contributor: Maybe we can be super safe and only do this for certain operators, like HashAggregateExec.
Contributor: Or introduce a config so that users can turn it off.
viirya (Member, Author): Added a config for it so we can turn it off.
val requireAllOutput = output.forall(parent.usedInputs.contains(_))
val consumeFunc =
if (SQLConf.get.wholeStageSplitConsumeFuncByOperator && requireAllOutput &&
ctx.isValidParamLength(output)) {
constructDoConsumeFunction(ctx, inputVars, row)
} else {
parent.doConsume(ctx, inputVars, rowVar)
}

[Review comment]
Contributor: Super nit:

val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator
if (confEnabled && ...)
s"""
|${ctx.registerComment(s"CONSUME: ${parent.simpleString}")}
|$evaluated
|${parent.doConsume(ctx, inputVars, rowVar)}
|$consumeFunc
""".stripMargin
}
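One way to observe the effect of this split (a sketch, assuming a spark-shell session; Spark's debug helper prints the generated Java source) is shown below. When splitting is enabled, the output contains a per-operator method such as project_doConsume:

// Dump the generated code for a simple pipeline and look for the split
// consume method. debugCodegen comes from the execution.debug package.
import org.apache.spark.sql.execution.debug._

val df = spark.range(10).selectExpr("id + 1 as c0", "id + 2 as c1")
df.debugCodegen()  // printed source should contain "project_doConsume" when enabled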

/**
* To prevent the concatenated function from growing too long to be optimized by JIT, we can
* separate the parent's `doConsume` code of a `CodegenSupport` operator into a function to call.
*/
private def constructDoConsumeFunction(
ctx: CodegenContext,
inputVars: Seq[ExprCode],
row: String): String = {
val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)
val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)

val doConsume = ctx.freshName("doConsume")
[Review thread]
Contributor: Shall we put the operator name in this function name?
viirya (Member, Author, Jan 25, 2018): The freshName here will add variablePrefix before doConsume, so it already has the operator name, e.g., agg_doConsume.
ctx.currentVars = inputVarsInFunc
ctx.INPUT_ROW = null

val doConsumeFuncName = ctx.addNewFunction(doConsume,
s"""
| private void $doConsume(${params.mkString(", ")}) throws java.io.IOException {
| ${parent.doConsume(ctx, inputVarsInFunc, rowVar)}
| }
""".stripMargin)

s"""
| $doConsumeFuncName(${args.mkString(", ")});
""".stripMargin
}

/**
* Returns the arguments for calling the consume function and the parameters of its method
* definition, along with the list of `ExprCode` for the parameters.
*/
private def constructConsumeParameters(
ctx: CodegenContext,
attributes: Seq[Attribute],
variables: Seq[ExprCode],
row: String): (Seq[String], Seq[String], Seq[ExprCode]) = {
val arguments = mutable.ArrayBuffer[String]()
val parameters = mutable.ArrayBuffer[String]()
val paramVars = mutable.ArrayBuffer[ExprCode]()

if (row != null) {
arguments += row
parameters += s"InternalRow $row"
}

[Review thread]
Contributor: We need to update ctx.isValidParamLength to count this.
Contributor: We should probably have two methods, one for calculating the param length and one for checking the param length limitation.
viirya (Member, Author): Added an extra unit for row if needed.

variables.zipWithIndex.foreach { case (ev, i) =>
val paramName = ctx.freshName(s"expr_$i")
val paramType = ctx.javaType(attributes(i).dataType)

arguments += ev.value
parameters += s"$paramType $paramName"
val paramIsNull = if (!attributes(i).nullable) {
// Use constant `false` without passing `isNull` for non-nullable variable.
"false"
} else {
val isNull = ctx.freshName(s"exprIsNull_$i")
arguments += ev.isNull
parameters += s"boolean $isNull"
isNull
}

paramVars += ExprCode("", paramIsNull, paramName)
}
(arguments, parameters, paramVars)
}
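As a worked illustration of the pairing above, here is a self-contained sketch (editorial, with hypothetical names and no Spark dependencies) for two columns, a non-nullable long and a nullable int:

import scala.collection.mutable

// Mirror the argument/parameter pairing: non-nullable columns pass only their
// value; nullable columns additionally pass their boolean null flag.
object ConsumeParamsSketch extends App {
  final case class Col(javaType: String, nullable: Boolean, value: String, isNull: String)

  val cols = Seq(
    Col("long", nullable = false, value = "agg_value_0", isNull = "false"),
    Col("int", nullable = true, value = "agg_value_1", isNull = "agg_isNull_1"))

  val arguments = mutable.ArrayBuffer[String]()
  val parameters = mutable.ArrayBuffer[String]()
  cols.zipWithIndex.foreach { case (c, i) =>
    arguments += c.value
    parameters += s"${c.javaType} expr_$i"
    if (c.nullable) {
      arguments += c.isNull
      parameters += s"boolean exprIsNull_$i"
    }
  }

  println(parameters.mkString(", "))  // long expr_0, int expr_1, boolean exprIsNull_1
  println(arguments.mkString(", "))   // agg_value_0, agg_value_1, agg_isNull_1
}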

/**
* Returns source code to evaluate all the variables and clear their code, to prevent
* them from being evaluated twice.
sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -205,7 +205,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
val codeWithShortFunctions = genGroupByCode(3)
val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions)
assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get)
val codeWithLongFunctions = genGroupByCode(20)
val codeWithLongFunctions = genGroupByCode(50)
[Review thread]
viirya (Member, Author, Oct 7, 2017): We reduced the length of the generated code, so to make this test work we increase the number of expressions.
Member: In my PR, I changed the code to just check that long functions have a larger max code size (https://github.com/apache/spark/pull/19082/files#diff-0314224342bb8c30143ab784b3805d19R185), but just increasing the value seems better.
val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions)
assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get)
}
@@ -228,4 +228,49 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
}
}
}

test("Control splitting consume function by operators with config") {
import testImplicits._
val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*)

Seq(true, false).foreach { config =>
withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") {
val plan = df.queryExecution.executedPlan
val wholeStageCodeGenExec = plan.find(p => p match {
case wp: WholeStageCodegenExec => true
case _ => false
})
assert(wholeStageCodeGenExec.isDefined)
val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
assert(code.body.contains("project_doConsume") == config)
}
}
}

test("Skip splitting consume function when parameter number exceeds JVM limit") {
import testImplicits._

Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) =>
withTempPath { dir =>
val path = dir.getCanonicalPath
spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*)
.write.mode(SaveMode.Overwrite).parquet(path)

withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255",
SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") {
val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i")
val df = spark.read.parquet(path).selectExpr(projection: _*)

val plan = df.queryExecution.executedPlan
val wholeStageCodeGenExec = plan.find(p => p match {
case wp: WholeStageCodegenExec => true
case _ => false
})
assert(wholeStageCodeGenExec.isDefined)
val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
assert(code.body.contains("project_doConsume") == hasSplit)
}
}
}
}
}