
[SPARK-21870][SQL] Split aggregation code into small functions #19082

Closed · wants to merge 4 commits
CodeGenerator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen

import java.io.ByteArrayInputStream
import java.lang.Character._
import java.util.{Map => JavaMap}

import scala.collection.JavaConverters._
@@ -1103,6 +1104,29 @@ class CodegenContext {
}
}

object CodegenContext {

private val javaKeywords = Set(
Member: Do we need to add an enum?

Member: An enum looks like overkill for now.

"abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class", "const",
"continue", "default", "do", "double", "else", "extends", "false", "final", "finally", "float",
"for", "goto", "if", "implements", "import", "instanceof", "int", "interface", "long", "native",
"new", "null", "package", "private", "protected", "public", "return", "short", "static",
"strictfp", "super", "switch", "synchronized", "this", "throw", "throws", "transient", "true",
"try", "void", "volatile", "while"
)

/**
* Returns true if the given `str` is a valid Java identifier: non-empty, not a reserved
* word, and composed of valid Java identifier characters.
*/
def isJavaIdentifier(str: String): Boolean = str match {
case null | "" =>
false
case _ =>
!javaKeywords.contains(str) && isJavaIdentifierStart(str.charAt(0)) &&
(1 until str.length).forall(i => isJavaIdentifierPart(str.charAt(i)))
}
}

/**
* A wrapper for generated class, defines a `generate` method so that we can pass extra objects
* into generated class.
SQLConf.scala
@@ -607,6 +607,17 @@ object SQLConf {
.intConf
.createWithDefault(100)

val MAX_PARAM_NUM_IN_JAVA_METHOD =
buildConf("spark.sql.codegen.maxParamNumInJavaMethod")
.internal()
.doc("The maximum number of parameters in codegened Java functions. When a function " +
"exceeds this threshold, the code generator gives up splitting the function code. " +
"This default value is 127 because the maximum length of parameters in non-static Java " +
"methods is 254 and a parameter of type long or double contributes " +
"two units to the length.")
.intConf
.createWithDefault(127)
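
Note: the arithmetic behind that default, as a quick plain-Scala sketch (not Spark code):

    // A JVM method descriptor allows at most 255 parameter slots; a non-static
    // method spends one slot on `this`, leaving 254, and a parameter of type
    // long or double occupies two slots each.
    val slotsForParams = 255 - 1            // 254 slots for explicit parameters
    val maxWideParams = slotsForParams / 2  // = 127, the default above
    assert(maxWideParams == 127)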

val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback")
.internal()
.doc("When true, (whole stage) codegen could be temporary disabled for the part of query that" +
@@ -1156,6 +1167,8 @@ class SQLConf extends Serializable with Logging {

def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS)

def maxParamNumInJavaMethod: Int = getConf(MAX_PARAM_NUM_IN_JAVA_METHOD)

def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK)

def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES)
CodeGenerationSuite.scala
@@ -394,4 +394,22 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
Map("add" -> Literal(1))).genCode(ctx)
assert(ctx.mutableStates.isEmpty)
}

test("SPARK-21870 check if CodegenContext.isJavaIdentifier works correctly") {
import CodegenContext.isJavaIdentifier
// positive cases
assert(isJavaIdentifier("agg_value"))
assert(isJavaIdentifier("agg_value1"))
assert(isJavaIdentifier("bhj_value4"))
assert(isJavaIdentifier("smj_value6"))
assert(isJavaIdentifier("rdd_value7"))
assert(isJavaIdentifier("scan_isNull"))
assert(isJavaIdentifier("test"))
// negative cases
assert(!isJavaIdentifier("true"))
assert(!isJavaIdentifier("false"))
assert(!isJavaIdentifier("390239"))
assert(!isJavaIdentifier(""""literal""""))
assert(!isJavaIdentifier(""""double""""))
}
}
HashAggregateExec.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.aggregate

import scala.collection.mutable

import org.apache.spark.TaskContext
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.rdd.RDD
@@ -257,6 +259,78 @@ case class HashAggregateExec(
""".stripMargin
}

// Extracts all the input variable references in a given `aggregateExpression`. The result
// is used to build the parameter list when splitting aggregation code into small functions.
private def getInputVariableReferences(
context: CodegenContext,
aggregateExpression: Expression,
subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
// `argSet` collects all the referenced variables as pairs; the first element of each pair
// is a type name and the second is a variable name.
val argSet = mutable.Set[(String, String)]()
val stack = mutable.Stack[Expression](aggregateExpression)
while (stack.nonEmpty) {
stack.pop() match {
case e if subExprs.contains(e) =>
val exprCode = subExprs(e)
if (CodegenContext.isJavaIdentifier(exprCode.value)) {
Contributor: Once we have @viirya's #20043 merged, we won't need the ugly CodegenContext.isJavaIdentifier hack any more >_<

Member Author: hey, good news! Thanks for letting me know ;)

argSet += ((context.javaType(e.dataType), exprCode.value))
}
if (CodegenContext.isJavaIdentifier(exprCode.isNull)) {
argSet += (("boolean", exprCode.isNull))
}
// The children may contain common subexpressions, so push them as well
stack.pushAll(e.children)
case ref: BoundReference
if context.currentVars != null && context.currentVars(ref.ordinal) != null =>
val value = context.currentVars(ref.ordinal).value
val isNull = context.currentVars(ref.ordinal).isNull
if (CodegenContext.isJavaIdentifier(value)) {
argSet += ((context.javaType(ref.dataType), value))
}
if (CodegenContext.isJavaIdentifier(isNull)) {
argSet += (("boolean", isNull))
}
case _: BoundReference =>
argSet += (("InternalRow", context.INPUT_ROW))
case e =>
stack.pushAll(e.children)
}
}

argSet.toSet
}
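
The loop above is a plain iterative depth-first walk with an explicit stack. A minimal,
self-contained sketch of the same pattern over hypothetical toy expression types (not
Catalyst's):

    import scala.collection.mutable

    sealed trait Expr { def children: Seq[Expr] }
    case class Leaf(tpe: String, name: String) extends Expr { def children = Nil }
    case class Add(left: Expr, right: Expr) extends Expr { def children = Seq(left, right) }

    def collectRefs(root: Expr): Set[(String, String)] = {
      val refs = mutable.Set[(String, String)]()
      val stack = mutable.Stack[Expr](root)
      while (stack.nonEmpty) stack.pop() match {
        case Leaf(t, n) => refs += ((t, n))  // record a (type, variable) pair
        case e => stack.pushAll(e.children)  // keep walking the children
      }
      refs.toSet
    }

    // collectRefs(Add(Leaf("long", "agg_value"), Leaf("boolean", "agg_isNull")))
    // => Set(("long", "agg_value"), ("boolean", "agg_isNull"))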

// Splits aggregate code into small functions because the JVM does not JIT-compile
// functions that are too long
private def splitAggregateExpressions(
context: CodegenContext,
aggregateExpressions: Seq[Expression],
codes: Seq[String],
subExprs: Map[Expression, SubExprEliminationState],
otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
aggregateExpressions.zipWithIndex.map { case (aggExpr, i) =>
val args = (getInputVariableReferences(context, aggExpr, subExprs) ++ otherArgs).toSeq

// Give up splitting the code if the number of parameters goes over
// `maxParamNumInJavaMethod`.
if (args.size <= sqlContext.conf.maxParamNumInJavaMethod) {
val doAggVal = context.freshName(s"doAggregateVal_${aggExpr.prettyName}")
val argList = args.map(a => s"${a._1} ${a._2}").mkString(", ")
val doAggValFuncName = context.addNewFunction(doAggVal,
s"""
| private void $doAggVal($argList) throws java.io.IOException {
| ${codes(i)}
| }
""".stripMargin)

val inputVariables = args.map(_._2).mkString(", ")
s"$doAggValFuncName($inputVariables);"
} else {
codes(i)
}
}
}
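
For illustration, a toy sketch of this splitting in plain Scala — nothing here is Spark
API; println stands in for CodegenContext.addNewFunction, and the doAggregateVal_$i names
are hypothetical stand-ins for freshName output:

    def split(
        codes: Seq[String],
        params: Seq[Seq[(String, String)]],  // per-aggregate (type, name) pairs
        maxParams: Int): Seq[String] = {
      codes.zip(params).zipWithIndex.map { case ((body, args), i) =>
        if (args.size <= maxParams) {
          val name = s"doAggregateVal_$i"
          val declList = args.map { case (t, n) => s"$t $n" }.mkString(", ")
          // emit the private helper that holds the per-aggregate code
          println(s"private void $name($declList) throws java.io.IOException { $body }")
          // the call site shrinks to a single invocation
          s"$name(${args.map(_._2).mkString(", ")});"
        } else {
          body  // too many parameters: give up and keep the code inline
        }
      }
    }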

private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
@@ -269,28 +343,53 @@ case class HashAggregateExec(
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
}
}
ctx.currentVars = bufVars ++ input

// We need to copy the aggregation buffer to local variables first because each aggregate
// function directly updates the buffer when it finishes.
Member Author: Just FYI: we need the local copies per this discussion, too: #19865
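
A minimal sketch of the hazard those local copies guard against (plain Scala, not
generated code): once each aggregate's evaluate-and-update runs as its own function, a
later aggregate could otherwise observe an earlier aggregate's write to the shared buffer:

    var buf = 10              // a shared aggregation buffer slot
    val localBuf = buf        // local copy taken before any update
    buf = buf + 5             // first aggregate updates the buffer in place
    val wrong = buf * 2       // reads 15: sees the first aggregate's write
    val right = localBuf * 2  // reads 10: the value both updates should see
    assert(wrong == 30 && right == 20)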

val localBufVars = bufVars.zip(updateExpr).map { case (ev, e) =>
val isNull = ctx.freshName("localBufIsNull")
val value = ctx.freshName("localBufValue")
val initLocalVars = s"""
| boolean $isNull = ${ev.isNull};
| ${ctx.javaType(e.dataType)} $value = ${ev.value};
""".stripMargin
ExprCode(initLocalVars, isNull, value)
}

val initLocalBufVar = evaluateVariables(localBufVars)

ctx.currentVars = localBufVars ++ input
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
}
// aggregate buffer should be updated atomic
val updates = aggVals.zipWithIndex.map { case (ev, i) =>

val evalAndUpdateCodes = aggVals.zipWithIndex.map { case (ev, i) =>
s"""
| // evaluate aggregate function
| ${ev.code}
| // update aggregation buffer
| ${bufVars(i).isNull} = ${ev.isNull};
| ${bufVars(i).value} = ${ev.value};
""".stripMargin
}

val updateAggValCode = splitAggregateExpressions(
context = ctx,
aggregateExpressions = boundUpdateExpr,
codes = evalAndUpdateCodes,
subExprs = subExprs.states)

s"""
| // do aggregate
| // copy aggregation buffer to the local
| $initLocalBufVar
| // common sub-expressions
| $effectiveCodes
| // evaluate aggregate function
| ${evaluateVariables(aggVals)}
| // update aggregation buffer
| ${updates.mkString("\n").trim}
| // process aggregate functions to update aggregation buffer
| ${updateAggValCode.mkString("\n")}
""".stripMargin
}

@@ -825,52 +924,92 @@ case class HashAggregateExec(
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input

val updateRowInRegularHashMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer
// We need to copy the aggregation row buffer to a local row first because each aggregate
// function directly updates the buffer when it finishes.
Contributor: Why does this matter? We should avoid unnecessary data copies as much as we can.

Member Author: We need this copy because of: #19082 (comment)

val localRowBuffer = ctx.freshName("localUnsafeRowBuffer")
val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();"

ctx.INPUT_ROW = localRowBuffer
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
}
val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>

val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
s"""
| // evaluate aggregate function
| ${ev.code}
| // update unsafe row buffer
| $updateColumnCode
""".stripMargin
}

val updateAggValCode = splitAggregateExpressions(
context = ctx,
aggregateExpressions = boundUpdateExpr,
codes = evalAndUpdateCodes,
subExprs = subExprs.states,
otherArgs = Seq(("InternalRow", unsafeRowBuffer)))

s"""
|// common sub-expressions
|$effectiveCodes
|// evaluate aggregate function
|${evaluateVariables(unsafeRowBufferEvals)}
|// update unsafe row buffer
|${updateUnsafeRowBuffer.mkString("\n").trim}
| // do aggregate
| // copy aggregation row buffer to the local
| $initLocalRowBuffer
| // common sub-expressions
| $effectiveCodes
| // process aggregate functions to update aggregation buffer
| ${updateAggValCode.mkString("\n")}
""".stripMargin
}

val updateRowInHashMap: String = {
if (isFastHashMapEnabled) {
ctx.INPUT_ROW = fastRowBuffer
// We need to copy the aggregation row buffer to a local row first because each aggregate
// function directly updates the buffer when it finishes.
val localRowBuffer = ctx.freshName("localFastRowBuffer")
val initLocalRowBuffer = s"InternalRow $localRowBuffer = $fastRowBuffer.copy();"

ctx.INPUT_ROW = localRowBuffer
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
}
val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>

val evalAndUpdateCodes = fastRowEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
ctx.updateColumn(
fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled)
val updateColumnCode = ctx.updateColumn(
fastRowBuffer, dt, i, ev, updateExpr(i).nullable)
s"""
| // evaluate aggregate function
| ${ev.code}
| // update fast row
| $updateColumnCode
""".stripMargin
}

val updateAggValCode = splitAggregateExpressions(
context = ctx,
aggregateExpressions = boundUpdateExpr,
codes = evalAndUpdateCodes,
subExprs = subExprs.states,
otherArgs = Seq(("InternalRow", fastRowBuffer)))

// If the fast hash map is enabled, we first generate code to update the row in it, for
// the case where the previous lookup hit the fast hash map. Otherwise, update the row in
// the regular hash map.
s"""
|if ($fastRowBuffer != null) {
| // copy aggregation row buffer to the local
| $initLocalRowBuffer
| // common sub-expressions
| $effectiveCodes
| // evaluate aggregate function
| ${evaluateVariables(fastRowEvals)}
| // update fast row
| ${updateFastRow.mkString("\n").trim}
| // process aggregate functions to update aggregation buffer
| ${updateAggValCode.mkString("\n")}
|} else {
| $updateRowInRegularHashMap
|}
WholeStageCodegenSuite.scala
@@ -211,11 +211,11 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {

test("SPARK-21871 check if we can get large code size when compiling too long functions") {
val codeWithShortFunctions = genGroupByCode(3)
val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions)
assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get)
val (_, smallCodeSize) = CodeGenerator.compile(codeWithShortFunctions)
val codeWithLongFunctions = genGroupByCode(20)
val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions)
assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get)
val (_, largeCodeSize) = CodeGenerator.compile(codeWithLongFunctions)
// Just check that longer functions produce a larger max code size
assert(largeCodeSize > smallCodeSize)
}

test("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") {
@@ -236,4 +236,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
}
}
}

test("SPARK-21870 check the case where the number of parameters goes over the limit") {
withSQLConf("spark.sql.codegen.maxParamNumInJavaMethod" -> "2") {
sql("CREATE OR REPLACE TEMPORARY VIEW t AS SELECT * FROM VALUES (1, 1, 1) AS t(a, b, c)")
val df = sql("SELECT SUM(a + b + c) AS sum FROM t")
assert(df.collect === Seq(Row(3)))
}
}
}
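
A hypothetical standalone repro of that last test outside the suite, assuming a running
SparkSession named spark (spark.conf.set and spark.sql are standard API; the conf key is
the one added by this PR):

    spark.conf.set("spark.sql.codegen.maxParamNumInJavaMethod", "2")
    spark.sql("CREATE OR REPLACE TEMPORARY VIEW t AS SELECT * FROM VALUES (1, 1, 1) AS t(a, b, c)")
    // with the threshold this low, the generated update code for SUM(a + b + c)
    // needs more parameters than allowed, so it stays inline; the result is
    // unchanged either way: a single row with sum = 3
    spark.sql("SELECT SUM(a + b + c) AS sum FROM t").show()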