-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-21870][SQL] Split aggregation code into small functions #19082
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,8 @@ | |
|
||
package org.apache.spark.sql.execution.aggregate | ||
|
||
import scala.collection.mutable | ||
|
||
import org.apache.spark.TaskContext | ||
import org.apache.spark.memory.TaskMemoryManager | ||
import org.apache.spark.rdd.RDD | ||
|
@@ -257,6 +259,78 @@ case class HashAggregateExec( | |
""".stripMargin | ||
} | ||
|
||
// Gathers the Java variables that the generated code for `aggregateExpression` reads. Each
// entry of the result is a (Java type name, variable name) pair; the result is used as the
// parameter list when the aggregation code is split into small functions.
private def getInputVariableReferences(
    context: CodegenContext,
    aggregateExpression: Expression,
    subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
  // Collected (Java type name, variable name) pairs.
  val collected = mutable.Set[(String, String)]()

  // Records `name` only when it is a plain Java identifier. The type is passed by name so
  // that it is computed only for names that are actually recorded.
  def record(javaType: => String, name: String): Unit = {
    if (CodegenContext.isJavaIdentifier(name)) {
      collected += ((javaType, name))
    }
  }

  // Explicit worklist of expressions that still have to be inspected.
  val pending = mutable.Stack[Expression](aggregateExpression)
  while (pending.nonEmpty) {
    val current = pending.pop()
    current match {
      // A common subexpression: its value/isNull variables were already generated.
      case common if subExprs.contains(common) =>
        val state = subExprs(common)
        record(context.javaType(common.dataType), state.value)
        record("boolean", state.isNull)
        // The children may contain common subexpressions as well, so keep descending.
        pending.pushAll(common.children)

      // A reference bound to an already generated input variable.
      case ref: BoundReference
          if context.currentVars != null && context.currentVars(ref.ordinal) != null =>
        val bound = context.currentVars(ref.ordinal)
        record(context.javaType(ref.dataType), bound.value)
        record("boolean", bound.isNull)

      // A reference without a generated variable: the code reads the current input row.
      case _: BoundReference =>
        collected += (("InternalRow", context.INPUT_ROW))

      case other =>
        pending.pushAll(other.children)
    }
  }

  collected.toSet
}
|
||
// Moves the code of each aggregate expression into its own private method, because JVMs do
// not compile functions that are too long. For every expression the result holds either a
// call to the newly added method or, when that method would need more than
// `maxParamNumInJavaMethod` parameters, the original code unchanged.
private def splitAggregateExpressions(
    context: CodegenContext,
    aggregateExpressions: Seq[Expression],
    codes: Seq[String],
    subExprs: Map[Expression, SubExprEliminationState],
    otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
  aggregateExpressions.zipWithIndex.map { case (aggExpr, i) =>
    // (Java type name, variable name) pairs that the new method has to receive.
    val params = (getInputVariableReferences(context, aggExpr, subExprs) ++ otherArgs).toSeq

    if (params.size > sqlContext.conf.maxParamNumInJavaMethod) {
      // Too many parameters: give up splitting and keep the code inline.
      codes(i)
    } else {
      val doAggVal = context.freshName(s"doAggregateVal_${aggExpr.prettyName}")
      val argList = params.map { case (javaType, name) => s"$javaType $name" }.mkString(", ")
      val doAggValFuncName = context.addNewFunction(doAggVal,
        s"""
           | private void $doAggVal($argList) throws java.io.IOException {
           | ${codes(i)}
           | }
         """.stripMargin)

      // Call site: pass the variables in the same order as the declared parameters.
      val inputVariables = params.map { case (_, name) => name }.mkString(", ")
      s"$doAggValFuncName($inputVariables);"
    }
  }
}
|
||
private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { | ||
// only have DeclarativeAggregate | ||
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) | ||
|
@@ -269,28 +343,53 @@ case class HashAggregateExec( | |
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions | ||
} | ||
} | ||
ctx.currentVars = bufVars ++ input | ||
|
||
// We need to copy the aggregation buffer to local variables first because each aggregate | ||
// function directly updates the buffer when it finishes. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just FYI: we need local copies per this discussion, too: #19865 |
||
val localBufVars = bufVars.zip(updateExpr).map { case (ev, e) => | ||
val isNull = ctx.freshName("localBufIsNull") | ||
val value = ctx.freshName("localBufValue") | ||
val initLocalVars = s""" | ||
| boolean $isNull = ${ev.isNull}; | ||
| ${ctx.javaType(e.dataType)} $value = ${ev.value}; | ||
""".stripMargin | ||
ExprCode(initLocalVars, isNull, value) | ||
} | ||
|
||
val initLocalBufVar = evaluateVariables(localBufVars) | ||
|
||
ctx.currentVars = localBufVars ++ input | ||
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) | ||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) | ||
val effectiveCodes = subExprs.codes.mkString("\n") | ||
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { | ||
boundUpdateExpr.map(_.genCode(ctx)) | ||
} | ||
// aggregate buffer should be updated atomically | ||
val updates = aggVals.zipWithIndex.map { case (ev, i) => | ||
|
||
val evalAndUpdateCodes = aggVals.zipWithIndex.map { case (ev, i) => | ||
s""" | ||
| // evaluate aggregate function | ||
| ${ev.code} | ||
| // update aggregation buffer | ||
| ${bufVars(i).isNull} = ${ev.isNull}; | ||
| ${bufVars(i).value} = ${ev.value}; | ||
""".stripMargin | ||
} | ||
|
||
val updateAggValCode = splitAggregateExpressions( | ||
context = ctx, | ||
aggregateExpressions = boundUpdateExpr, | ||
codes = evalAndUpdateCodes, | ||
subExprs = subExprs.states) | ||
|
||
s""" | ||
| // do aggregate | ||
| // copy aggregation buffer to the local | ||
| $initLocalBufVar | ||
| // common sub-expressions | ||
| $effectiveCodes | ||
| // evaluate aggregate function | ||
| ${evaluateVariables(aggVals)} | ||
| // update aggregation buffer | ||
| ${updates.mkString("\n").trim} | ||
| // process aggregate functions to update aggregation buffer | ||
| ${updateAggValCode.mkString("\n")} | ||
""".stripMargin | ||
} | ||
|
||
|
@@ -825,52 +924,92 @@ case class HashAggregateExec( | |
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input | ||
|
||
val updateRowInRegularHashMap: String = { | ||
ctx.INPUT_ROW = unsafeRowBuffer | ||
// We need to copy the aggregation row buffer to a local row first because each aggregate | ||
// function directly updates the buffer when it finishes. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why does this matter? We should avoid unnecessary data copies as much as we can. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need this copy because: #19082 (comment) |
||
val localRowBuffer = ctx.freshName("localUnsafeRowBuffer") | ||
val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();" | ||
|
||
ctx.INPUT_ROW = localRowBuffer | ||
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) | ||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) | ||
val effectiveCodes = subExprs.codes.mkString("\n") | ||
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { | ||
boundUpdateExpr.map(_.genCode(ctx)) | ||
} | ||
val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => | ||
|
||
val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => | ||
val dt = updateExpr(i).dataType | ||
ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) | ||
val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) | ||
s""" | ||
| // evaluate aggregate function | ||
| ${ev.code} | ||
| // update unsafe row buffer | ||
| $updateColumnCode | ||
""".stripMargin | ||
} | ||
|
||
val updateAggValCode = splitAggregateExpressions( | ||
context = ctx, | ||
aggregateExpressions = boundUpdateExpr, | ||
codes = evalAndUpdateCodes, | ||
subExprs = subExprs.states, | ||
otherArgs = Seq(("InternalRow", unsafeRowBuffer))) | ||
|
||
s""" | ||
|// common sub-expressions | ||
|$effectiveCodes | ||
|// evaluate aggregate function | ||
|${evaluateVariables(unsafeRowBufferEvals)} | ||
|// update unsafe row buffer | ||
|${updateUnsafeRowBuffer.mkString("\n").trim} | ||
| // do aggregate | ||
| // copy aggregation row buffer to the local | ||
| $initLocalRowBuffer | ||
| // common sub-expressions | ||
| $effectiveCodes | ||
| // process aggregate functions to update aggregation buffer | ||
| ${updateAggValCode.mkString("\n")} | ||
""".stripMargin | ||
} | ||
|
||
val updateRowInHashMap: String = { | ||
if (isFastHashMapEnabled) { | ||
ctx.INPUT_ROW = fastRowBuffer | ||
// We need to copy the aggregation row buffer to a local row first because each aggregate | ||
// function directly updates the buffer when it finishes. | ||
val localRowBuffer = ctx.freshName("localFastRowBuffer") | ||
val initLocalRowBuffer = s"InternalRow $localRowBuffer = $fastRowBuffer.copy();" | ||
|
||
ctx.INPUT_ROW = localRowBuffer | ||
val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) | ||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) | ||
val effectiveCodes = subExprs.codes.mkString("\n") | ||
val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { | ||
boundUpdateExpr.map(_.genCode(ctx)) | ||
} | ||
val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => | ||
|
||
val evalAndUpdateCodes = fastRowEvals.zipWithIndex.map { case (ev, i) => | ||
val dt = updateExpr(i).dataType | ||
ctx.updateColumn( | ||
fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) | ||
val updateColumnCode = ctx.updateColumn( | ||
fastRowBuffer, dt, i, ev, updateExpr(i).nullable) | ||
s""" | ||
| // evaluate aggregate function | ||
| ${ev.code} | ||
| // update fast row | ||
| $updateColumnCode | ||
""".stripMargin | ||
} | ||
|
||
val updateAggValCode = splitAggregateExpressions( | ||
context = ctx, | ||
aggregateExpressions = boundUpdateExpr, | ||
codes = evalAndUpdateCodes, | ||
subExprs = subExprs.states, | ||
otherArgs = Seq(("InternalRow", fastRowBuffer))) | ||
|
||
// If the fast hash map is on, we first generate code to update the row in the fast hash map, if the | ||
// previous lookup hit the fast hash map. Otherwise, we update the row in the regular hash map. | ||
s""" | ||
|if ($fastRowBuffer != null) { | ||
| // copy aggregation row buffer to the local | ||
| $initLocalRowBuffer | ||
| // common sub-expressions | ||
| $effectiveCodes | ||
| // evaluate aggregate function | ||
| ${evaluateVariables(fastRowEvals)} | ||
| // update fast row | ||
| ${updateFastRow.mkString("\n").trim} | ||
| // process aggregate functions to update aggregation buffer | ||
| ${updateAggValCode.mkString("\n")} | ||
|} else { | ||
| $updateRowInRegularHashMap | ||
|} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to add
enum
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc: @rednaxelafx
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
enum looks like overkill for now