[SPARK-21717][SQL] Decouple consume functions of physical operators in whole-stage codegen

## What changes were proposed in this pull request?

It has been observed in SPARK-21603 that whole-stage codegen suffers performance degradation if the generated functions are too long for the JIT compiler to optimize.

We currently produce a single function that incorporates the generated code of all physical operators in a whole-stage plan. The generated function can therefore grow past the threshold beyond which the JIT no longer optimizes it.

This patch decouples the row-consuming logic of each physical operator into its own function, so that no single giant function has to process all rows.
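
For intuition, a hedged sketch of how the shape of the generated code changes; every identifier here (`scan_value_0`, `project_doConsume`, ...) is an illustrative stand-in for the fresh names the code generator actually produces:

```scala
// Before: all operators' consume logic is inlined into one huge method.
// After: an operator's consume logic can live in its own method.
val generatedShape =
  """
    |protected void processNext() throws java.io.IOException {
    |  // ... leaf operator's produce loop ...
    |  // Instead of inlining the parent's consume logic here, call:
    |  project_doConsume(scan_value_0, scan_isNull_0);
    |}
    |
    |private void project_doConsume(long expr_0, boolean exprIsNull_0)
    |    throws java.io.IOException {
    |  // ... the project operator's consume logic, formerly inlined above ...
    |}
  """.stripMargin
```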

## How was this patch tested?

Added tests.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #18931 from viirya/SPARK-21717.
viirya authored and cloud-fan committed Jan 25, 2018
1 parent 39ee2ac commit d20bbc2
Showing 4 changed files with 203 additions and 29 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1245,6 +1245,31 @@ class CodegenContext {
""
}
}

/**
* Returns the length, in parameter slots, of a Java method descriptor: `this` contributes one
* slot and a parameter of type long or double contributes two slots. Besides, for each nullable
* parameter, we also need to pass a boolean parameter for the null status.
*/
def calculateParamLength(params: Seq[Expression]): Int = {
  def paramLengthForExpr(input: Expression): Int = {
    val javaParamLength = javaType(input.dataType) match {
      case JAVA_LONG | JAVA_DOUBLE => 2
      case _ => 1
    }
    // For a nullable expression, we need to pass an extra boolean parameter for its null status.
    (if (input.nullable) 1 else 0) + javaParamLength
  }
  // Initial value is 1 for `this`.
  1 + params.map(paramLengthForExpr).sum
}

/**
* In the JVM, a method descriptor is valid only if the total length of its method parameters
* does not exceed a predefined limit of 255 slots.
*/
def isValidParamLength(paramLength: Int): Boolean = {
paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
}
}
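
As a worked check of the slot arithmetic, a self-contained sketch (assuming the usual JVM descriptor rules; `Col` is a stand-in for `Expression`):

```scala
// Minimal mimic of calculateParamLength: `this` takes 1 slot, long/double
// take 2 slots, everything else takes 1, and each nullable parameter adds
// 1 slot for its boolean null flag.
case class Col(javaType: String, nullable: Boolean)

def paramLength(cols: Seq[Col]): Int =
  1 + cols.map { c =>
    val width = c.javaType match {
      case "long" | "double" => 2
      case _ => 1
    }
    (if (c.nullable) 1 else 0) + width
  }.sum

// A nullable long (2 + 1) plus a non-nullable int (1), plus 1 for `this`:
assert(paramLength(Seq(Col("long", nullable = true), Col("int", nullable = false))) == 5)
// 84 nullable longs fit within the 255-slot limit; 85 do not:
assert(paramLength(Seq.fill(84)(Col("long", nullable = true))) == 253)
assert(paramLength(Seq.fill(85)(Col("long", nullable = true))) == 256)
```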

/**
@@ -1311,26 +1336,29 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging
object CodeGenerator extends Logging {

// This is the value of HugeMethodLimit in the OpenJDK JVM settings.
final val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000

// The maximum valid length of method parameters in the JVM.
final val MAX_JVM_METHOD_PARAMS_LENGTH = 255

// This is the threshold over which the methods in an inner class are grouped into a single
// method, which is then called by the outer class instead of the many small ones.
final val MERGE_SPLIT_METHODS_THRESHOLD = 3

// The number of named constants that can exist in a class is limited by the constant pool
// limit, 65,536. We cannot know in advance how many constants a class will need, so we use a
// threshold of 1,000,000 bytes of generated source to determine when a function should be
// inlined into a private, inner class.
final val GENERATED_CLASS_SIZE_THRESHOLD = 1000000

// This is the threshold for the number of global variables, whose types are primitive or
// complex (e.g. a multi-dimensional array), that will be placed in the outer class.
final val OUTER_CLASS_VARIABLES_THRESHOLD = 10000

// This is the maximum number of array elements used to keep global variables in one Java array.
// 32767 is the maximum integer value that does not require a constant pool entry in a Java
// bytecode instruction.
final val MUTABLESTATEARRAY_SIZE_LIMIT = 32768

/**
* Compile the Java source code into a Java class, using Janino.
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -661,6 +661,15 @@ object SQLConf {
.intConf
.createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT)

val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR =
buildConf("spark.sql.codegen.splitConsumeFuncByOperator")
.internal()
.doc("When true, whole stage codegen would put the logic of consuming rows of each " +
"physical operator into individual methods, instead of a single big method. This can be " +
"used to avoid oversized function that can miss the opportunity of JIT optimization.")
.booleanConf
.createWithDefault(true)
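
A hedged usage sketch of the new flag (the key and default come from the definition above; the query shape is illustrative):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()

// The flag defaults to true; turning it off restores the old behavior of
// inlining every operator's consume logic into one big method.
spark.conf.set("spark.sql.codegen.splitConsumeFuncByOperator", "false")

val df = spark.range(10).selectExpr("id + 1 AS c0", "id + 2 AS c1")
df.collect()
```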

val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes")
.doc("The maximum number of bytes to pack into a single partition when reading files.")
.longConf
@@ -1263,6 +1272,9 @@ class SQLConf extends Serializable with Logging {

def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT)

def wholeStageSplitConsumeFuncByOperator: Boolean =
getConf(WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR)

def tableRelationCacheSize: Int =
getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE)

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution

import java.util.Locale

import scala.collection.mutable

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -106,6 +108,31 @@ trait CodegenSupport extends SparkPlan {
*/
protected def doProduce(ctx: CodegenContext): String

private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = {
if (row != null) {
ExprCode("", "false", row)
} else {
if (colVars.nonEmpty) {
val colExprs = output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}
val evaluateInputs = evaluateVariables(colVars)
// Generate the code to create an UnsafeRow.
ctx.INPUT_ROW = row
ctx.currentVars = colVars
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
val code = s"""
|$evaluateInputs
|${ev.code.trim}
""".stripMargin.trim
ExprCode(code, "false", ev.value)
} else {
// There are no output columns.
ExprCode("", "false", "unsafeRow")
}
}
}
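
To make the three branches above concrete, a hedged summary of the `ExprCode` each returns; the triple is (code to run, isNull expression, value expression), and `projectCode` stands in for the generated UnsafeRow-building code:

```scala
// row != null       -> ExprCode("",          "false", row)          // reuse the given row variable
// colVars.nonEmpty  -> ExprCode(projectCode, "false", ev.value)     // build an UnsafeRow from columns
// no output columns -> ExprCode("",          "false", "unsafeRow")  // fall back to the `unsafeRow` term
```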

/**
* Consume the generated columns or row from the current SparkPlan and call its parent's `doConsume()`.
*
@@ -126,28 +153,7 @@ trait CodegenSupport extends SparkPlan {
}
}

val rowVar = prepareRowVar(ctx, row, outputVars)

// Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
// before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
@@ -156,13 +162,96 @@
ctx.INPUT_ROW = null
ctx.freshNamePrefix = parent.variablePrefix
val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)

// Under certain conditions, we can put the logic of consuming the rows of this operator into
// another function, so that the generated function does not grow too long for JIT optimization.
// The conditions are:
// 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled.
// 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses
//    all variables in output (see `requireAllOutput`).
// 3. The number of output variables must be less than the maximum number of parameters allowed
//    in a Java method declaration.
val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator
val requireAllOutput = output.forall(parent.usedInputs.contains(_))
val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0)
val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) {
constructDoConsumeFunction(ctx, inputVars, row)
} else {
parent.doConsume(ctx, inputVars, rowVar)
}
s"""
|${ctx.registerComment(s"CONSUME: ${parent.simpleString}")}
|$evaluated
|$consumeFunc
""".stripMargin
}

/**
* To prevent the concatenated function from growing too long for JIT optimization, we can
* separate the parent's `doConsume` code of a `CodegenSupport` operator into a function to call.
*/
private def constructDoConsumeFunction(
ctx: CodegenContext,
inputVars: Seq[ExprCode],
row: String): String = {
val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)
val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)

val doConsume = ctx.freshName("doConsume")
ctx.currentVars = inputVarsInFunc
ctx.INPUT_ROW = null

val doConsumeFuncName = ctx.addNewFunction(doConsume,
s"""
| private void $doConsume(${params.mkString(", ")}) throws java.io.IOException {
| ${parent.doConsume(ctx, inputVarsInFunc, rowVar)}
| }
""".stripMargin)

s"""
| $doConsumeFuncName(${args.mkString(", ")});
""".stripMargin
}

/**
* Returns the arguments used to call the consume function and the parameters of its definition,
* along with the list of `ExprCode` for those parameters.
*/
private def constructConsumeParameters(
ctx: CodegenContext,
attributes: Seq[Attribute],
variables: Seq[ExprCode],
row: String): (Seq[String], Seq[String], Seq[ExprCode]) = {
val arguments = mutable.ArrayBuffer[String]()
val parameters = mutable.ArrayBuffer[String]()
val paramVars = mutable.ArrayBuffer[ExprCode]()

if (row != null) {
arguments += row
parameters += s"InternalRow $row"
}

variables.zipWithIndex.foreach { case (ev, i) =>
val paramName = ctx.freshName(s"expr_$i")
val paramType = ctx.javaType(attributes(i).dataType)

arguments += ev.value
parameters += s"$paramType $paramName"
val paramIsNull = if (!attributes(i).nullable) {
// Use the constant `false` without passing `isNull` for a non-nullable variable.
"false"
} else {
val isNull = ctx.freshName(s"exprIsNull_$i")
arguments += ev.isNull
parameters += s"boolean $isNull"
isNull
}

paramVars += ExprCode("", paramIsNull, paramName)
}
(arguments, parameters, paramVars)
}
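
For example (hedged; the fresh-name suffixes are illustrative), with `row == null` and two input columns, a nullable long followed by a non-nullable int, the call site and the definition pair up as:

```scala
// Call-site arguments: the operator's already-evaluated value/isNull variables.
val arguments = Seq("project_value_0", "project_isNull_0", "project_value_1")
// Definition parameters: fresh names the parent's consume code will read.
val parameters = Seq("long expr_0", "boolean exprIsNull_0", "int expr_1")

println(s"project_doConsume(${arguments.mkString(", ")});")
// project_doConsume(project_value_0, project_isNull_0, project_value_1);
println(s"private void project_doConsume(${parameters.mkString(", ")})")
// private void project_doConsume(long expr_0, boolean exprIsNull_0, int expr_1)
```

The non-nullable int contributes no boolean argument: inside the function its `ExprCode` simply uses the constant `"false"` as the isNull expression.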

/**
* Returns source code to evaluate all the variables and clears their code, to prevent
* them from being evaluated twice.
sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -205,7 +205,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
val codeWithShortFunctions = genGroupByCode(3)
val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions)
assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get)
val codeWithLongFunctions = genGroupByCode(50)
val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions)
assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get)
}
@@ -228,4 +228,49 @@
}
}
}

test("Control splitting consume function by operators with config") {
import testImplicits._
val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*)

Seq(true, false).foreach { config =>
withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") {
val plan = df.queryExecution.executedPlan
val wholeStageCodeGenExec = plan.find {
  case _: WholeStageCodegenExec => true
  case _ => false
}
assert(wholeStageCodeGenExec.isDefined)
val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
assert(code.body.contains("project_doConsume") == config)
}
}
}
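
Outside the suite, a hedged way to eyeball the effect from a spark-shell (`debugCodegen` is the standard debug helper; the method-name prefix matches what the test asserts on):

```scala
import org.apache.spark.sql.execution.debug._

// Dump the whole-stage generated code and look for per-operator consume
// methods, e.g. names containing "project_doConsume".
val df = spark.range(10).selectExpr("id + 1 AS c0", "id + 2 AS c1")
df.debugCodegen()
```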

test("Skip splitting consume function when parameter number exceeds JVM limit") {
import testImplicits._

// Columns read back from parquet are nullable, so each long column costs 3 parameter slots
// (2 for the long value, 1 for its null flag), plus 1 slot for `this`: 84 columns take
// 253 slots (within the 255 limit), while 85 columns take 256 (over the limit).
Seq((85, false), (84, true)).foreach { case (columnNum, hasSplit) =>
withTempPath { dir =>
val path = dir.getCanonicalPath
spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*)
.write.mode(SaveMode.Overwrite).parquet(path)

withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255",
SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") {
val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i")
val df = spark.read.parquet(path).selectExpr(projection: _*)

val plan = df.queryExecution.executedPlan
val wholeStageCodeGenExec = plan.find {
  case _: WholeStageCodegenExec => true
  case _ => false
}
assert(wholeStageCodeGenExec.isDefined)
val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
assert(code.body.contains("project_doConsume") == hasSplit)
}
}
}
}
}
